diff --git a/.gitignore b/.gitignore index 294863ce8ac98840299ea4dfcb8d78ddb8249eb1..cdf5b64dece05d4fe72e023c9859a60f76124497 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,4 @@ build* *~ # GDB temporary files -.gdb_history \ No newline at end of file +.gdb_history diff --git a/CMakeLists.txt b/CMakeLists.txt index b7ad225e2bdd883ae408e4502ba8fd4870f4954c..3c59f574874d99cadf36e0b886103e5f6dfb49da 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,8 @@ add_definitions(-DCK_NOGPU) endif() if(NOT CK_NOGPU) -find_package(ROCM REQUIRED PATHS /opt/rocm) +set(ROCM_SYMLINK_LIBS OFF) +find_package(ROCM 0.8 REQUIRED PATHS /opt/rocm) include(ROCMInstallTargets) include(ROCMPackageConfigHelpers) @@ -22,7 +23,7 @@ include(ROCMInstallSymlinks) include(ROCMCreatePackage) include(CheckCXXCompilerFlag) -rocm_setup_version(VERSION 1.0.0) +rocm_setup_version(VERSION 0.2.0) include(TargetFlags) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) endif() @@ -84,19 +85,6 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) endif() message(STATUS "Build with HIP ${HIP_VERSION}") - -rocm_create_package( - NAME composablekernel - DESCRIPTION "High Performance Composable Kernel for AMD GPUs" - MAINTAINER "MIOpen Kernels Dev Team " - LDCONFIG -) -endif() - -## half -set(HALF_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/external/include/half") -message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}") - ## tidy include(EnableCompilerWarnings) set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) @@ -250,7 +238,6 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_BINARY_DIR}/include ${PROJECT_SOURCE_DIR}/library/include ) @@ -264,6 +251,11 @@ message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +rocm_package_setup_component(tests + LIBRARY_NAME composablekernel + PACKAGE_NAME tests # Prevent -static suffix on package name +) + add_subdirectory(library) add_subdirectory(example) add_subdirectory(test) @@ -285,8 +277,19 @@ configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in NO_CHECK_REQUIRED_COMPONENTS_MACRO ) -install(FILES +rocm_install(FILES "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) + +set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") +set(CPACK_RPM_PACKAGE_LICENSE "MIT") + +rocm_create_package( + NAME composablekernel + DESCRIPTION "High Performance Composable Kernel for AMD GPUs" + MAINTAINER "MIOpen Kernels Dev Team " + LDCONFIG + HEADER_ONLY +) diff --git a/Dockerfile b/Dockerfile index 79c961144a3af60b32a95a876111ab4a870596e1..0d32b52f75ac89b0138810af692d7a8177e38f0e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -88,3 +88,8 @@ ADD rbuild.ini /rbuild.ini ADD dev-requirements.txt dev-requirements.txt RUN rbuild prepare -s develop -d $PREFIX RUN groupadd -f render + +# Install the new rocm-cmake version +RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . 
--target install diff --git a/Jenkinsfile b/Jenkinsfile index beac2ea248fb390150661537625211d7e0abbf8b..15be3e540c49aef417b4f5401eb75d67d41c4465 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -7,7 +7,6 @@ def show_node_info() { echo "NODE_NAME = \$NODE_NAME" lsb_release -sd uname -r - cat /sys/module/amdgpu/version ls /opt/ -la """ } @@ -101,7 +100,8 @@ def buildHipClangJob(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { if (params.USE_DOCKERFILE){ try { retimage = docker.build("${image}", dockerArgs + '.') @@ -191,7 +191,8 @@ def runCKProfiler(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { if (params.USE_DOCKERFILE){ try { retimage = docker.build("${image}", dockerArgs + '.') @@ -317,6 +318,7 @@ pipeline { dbsshport = "${dbsshport}" dbsshuser = "${dbsshuser}" dbsshpassword = "${dbsshpassword}" + status_wrapper_creds = "${status_wrapper_creds}" } stages{ stage("Static checks") { @@ -386,7 +388,7 @@ pipeline { agent{ label rocmnode("gfx908")} environment{ setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """ - execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc .. && make -j """ } steps{ buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2fe9a8455efaeda2eab474b2aa038ec2d9e76841 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) +Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) +Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) +Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) +Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) +Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) +Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) + +SPDX-License-Identifier: MIT +Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 9d7b578046a5e11cafae9ac91ac9419dbf02050a..aa1100dd1381907904ecdfb479ec5aa2609c8798 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,13 @@ docker run \ --group-add sudo \ -w /root/workspace \ -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +rocm/tensorflow:rocm5.1-tf2.6-dev \ /bin/bash ``` +# Install the new rocm-cmake version +https://github.com/RadeonOpenCompute/rocm-cmake + ## Build ```bash mkdir build && cd build @@ -23,6 +26,7 @@ cmake \ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_INSTALL_PREFIX=${PATH_TO_CK_INSTALL_DIRECTORY} \ .. ``` @@ -34,7 +38,7 @@ Instructions for running each individual examples are under ```example/``` ## Tests ```bash - make -j tests + make -j examples tests make test ``` @@ -44,6 +48,12 @@ Instructions for running each individual examples are under ```example/``` ``` Instructions for running ckProfiler are under ```profiler/``` +## Install CK +```bash +make install +``` + +## Using CK as pre-built kernel library ## Caveat ### Kernel Timing and Verification diff --git a/client_example/01_gemm/CMakeLists.txt b/client_example/01_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9e741192f90b8216e4b3abe32ae8971fb45ddfee --- /dev/null +++ b/client_example/01_gemm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_gemm gemm.cpp) +target_link_libraries(client_gemm PRIVATE composable_kernel::device_operations) diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b7b7a66039b8114dfa10699cf1996383a56e27e --- /dev/null +++ b/client_example/01_gemm/gemm.cpp @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideC = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + + using DeviceOp = + ck::tensor_operation::device::DeviceGemm; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * 
M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1064abc8fa813c837d2f85ad61e340517a24e70d --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) +target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations) diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbf2e634f0c9aa11d10639e58576988bef7883c3 --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using D1DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DDELayout = Row; +using DDELayout = Row; +using DELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 9) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideD1 = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + else + { + printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, DDELayout{})); + SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) * + f_matrix_space_size(M, N, StrideD1, DDELayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_matrix_space_size(M, N, StrideE, DELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + DDELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + 
float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3742e70844b96575e263b22a14b0bb8c4cde7a43 --- /dev/null +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp) +target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations) diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f142937281a712d1004e15a578fc64d6501d473 --- /dev/null +++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; +using BiasDataType = F32; +using CDataType = F16; +using D0DataType = F16; +using ReduceDataType = F32; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +bool RunDeviceGemmMeanSquareMean(gemm_reduce_op_ptr& p_op, + const void* p_a, + const void* p_b, + const void* p_bias, + const void* p_d0, + void* p_c, + void* p_mean, + void* p_square_mean, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int StrideD0, + bool time_kernel) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + + auto passOp = PassThrough{}; + auto squareOp = UnarySquareElementOp{}; + auto divOp = UnaryDivElementOp{N}; + + auto argument_ptr = + p_op->MakeArgumentPointer(p_a, + p_b, + p_bias, + {p_d0}, + p_c, + {p_mean, p_square_mean}, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {StrideD0}, + {&passOp, &passOp, &passOp}, // functor for a, b, c + {&passOp}, // functor for d0 + {&passOp, &squareOp}, // functor for inputs of reduction + {&divOp, &divOp}); // functor for outputs of reduction + + if(p_op->IsSupportedArgument(argument_ptr.get())) + { + auto invoker_ptr = p_op->MakeInvokerPointer(); + + // If we evaluate running time of gemm_reduce. The output may wrong. + // Because we need to initialize the reduction tensor before runing the kernel. + // However we run kernel many times for time_kernel = trie without reinitialize the out + // of reduction tensor. 
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + std::cout << "Gemm + reduce Perf: " << std::setw(10) << ave_time << " ms" << std::endl; + + return true; + } + + return false; +} + +template +bool RunDeviceNormalize2D(normalize_op_ptr& p_op, + const void* p_x, + const void* p_mean, + const void* p_square_mean, + const void* p_gamma, + const void* p_beta, + void* p_y, + int M, + int N, + int StrideX, + bool time_kernel) +{ + std::array input = {p_x, p_mean, p_square_mean, p_gamma, p_beta}; + std::array output = {p_y}; + auto normalize_functor = ck::tensor_operation::element_wise::Normalize{}; + + auto argument_ptr = p_op->MakeArgumentPointer(input, + output, + {M, N}, + {{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}}, + {{StrideX, 1}}, + ck::tensor_operation::element_wise::Normalize{}); + + if(p_op->IsSupportedArgument(argument_ptr.get())) + { + auto invoker_ptr = p_op->MakeInvokerPointer(); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + std::cout << "Normalize Perf: " << std::setw(10) << ave_time << " ms" << std::endl; + + return true; + } + + return false; +} + +int main() +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideC = 1024; + ck::index_t StrideD0 = 1024; + + const auto gemm_reduce_ptrs = + ck::tensor_operation::device::instance::get_device_gemm_add_add_mean_squaremean_instances< + ADataType, + BDataType, + CDataType, + ALayout, + BLayout, + CLayout>(); + + const auto normalize_ptrs = + ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances< + CDataType, + ReduceDataType, + ReduceDataType, + GammaDataType, + BetaDataType, + LayerNormOutDataType>(); + + std::cout << "found " << gemm_reduce_ptrs.size() + << " gemm_reduceMean_reduceSquareMean instances" << std::endl; + + std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl; + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem bias_device_buf(sizeof(BiasDataType) * N); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, CLayout{})); + SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M); + SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); + SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N); + + bool b_time_kernel = true; + bool b_only_run_first_kernel = true; + + // layernorm => (1) + (2) + // (1). c = gemm(a, b), reduce_mean(c), reduce_square_mean(c) + // (2). 
normalize(c, mean, square_mean, gamma, beta) + for(auto& gemm_reduce_ptr : gemm_reduce_ptrs) + { + // run first available kernel + if(RunDeviceGemmMeanSquareMean(gemm_reduce_ptr, + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + d0_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideD0, + b_time_kernel)) + { + if(b_only_run_first_kernel) + break; + } + else + { + std::cout << gemm_reduce_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + + for(auto& normalize_ptr : normalize_ptrs) + { + if(RunDeviceNormalize2D(normalize_ptr, + c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + layerNorm_device_buf.GetDeviceBuffer(), + M, + N, + StrideC, + b_time_kernel)) + { + if(b_only_run_first_kernel) + break; + } + else + { + std::cout << normalize_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } +} diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4bc6780f96d2fe4a4912e3c188b4b5155cc162dd --- /dev/null +++ b/client_example/04_contraction/CMakeLists.txt @@ -0,0 +1,6 @@ +add_executable(client_contraction_scale contraction_scale.cpp) +target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device_operations) + +add_executable(client_contraction_bilinear contraction_bilinear.cpp) +target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations) + diff --git a/client_example/04_contraction/contraction_bilinear.cpp b/client_example/04_contraction/contraction_bilinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b71c51c02620ce62257e3b33a6165a1c8ddda2b1 --- /dev/null +++ b/client_example/04_contraction/contraction_bilinear.cpp @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Bilinear; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 25) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])}; + + alpha = std::stof(argv[23]); + beta = std::stof(argv[24]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg23 to 24: alpha, beta\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, 
a_ms_ks_strides)); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * + f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides)); + SimpleDeviceMem d_device_buf(sizeof(DDataType) * + f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides)); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides)); + + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD< + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{alpha, beta}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), + e_ms_ns_lengths.begin() + NumDimM, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, + e_ms_ns_lengths.begin() + NumDimM + NumDimN, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, + a_ms_ks_lengths.begin() + NumDimM + NumDimK, + ck::index_t{1}, + std::multiplies{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/04_contraction/contraction_scale.cpp b/client_example/04_contraction/contraction_scale.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..5908c1d86e678796dec3d2616c83e9fca40595fb --- /dev/null +++ b/client_example/04_contraction/contraction_scale.cpp @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp" + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Scale; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 20) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + scale = std::stof(argv[19]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg19: scale\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides)); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * + f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides)); + SimpleDeviceMem 
e_device_buf(sizeof(EDataType) * + f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides)); + + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD< + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{scale}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), + e_ms_ns_lengths.begin() + NumDimM, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, + e_ms_ns_lengths.begin() + NumDimM + NumDimN, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, + a_ms_ks_lengths.begin() + NumDimM + NumDimK, + ck::index_t{1}, + std::multiplies{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3e04a18599a7b488fb306cbaf598494bd48b69d5 --- /dev/null +++ b/client_example/CMakeLists.txt @@ -0,0 +1,12 @@ +cmake_minimum_required(VERSION 3.15) +project(ck_app) +add_compile_options(-std=c++17) + +find_package(composable_kernel 1.0.0 COMPONENTS device_operations) +find_package(hip REQUIRED PATHS /opt/rocm) +message(STATUS "Build with HIP ${hip_VERSION}") + 
+add_subdirectory(01_gemm) +add_subdirectory(02_gemm_add_add_fastgelu) +add_subdirectory(03_gemm_layernorm) +add_subdirectory(04_contraction) diff --git a/client_example/README.md b/client_example/README.md new file mode 100644 index 0000000000000000000000000000000000000000..64a7130d537b1e2fb8752c4031e8430d11a6a46a --- /dev/null +++ b/client_example/README.md @@ -0,0 +1,21 @@ +## +Client application links to CK library, and therefore CK library needs to be installed before building client applications. + + +## Build +```bash +mkdir -p client_example/build +cd client_example/build +``` + +```bash +cmake \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \ +.. +``` + +### Build client example +```bash + make -j +``` diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 959bc4f4b0e26bb9c8e86a68eb34ed692041722c..3718b916ffe43996852507881db281dc5647fef0 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -8,7 +8,7 @@ endif() message(STATUS "Fetching GoogleTest") -list(APPEND GTEST_CMAKE_CXX_FLAGS +list(APPEND GTEST_CMAKE_CXX_FLAGS -Wno-undef -Wno-reserved-identifier -Wno-global-constructors @@ -31,7 +31,11 @@ FetchContent_Declare( # Will be necessary for windows build # set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -FetchContent_MakeAvailable(googletest) +FetchContent_GetProperties(googletest) +if(NOT googletest_POPULATED) + FetchContent_Populate(googletest) + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) +endif() target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index 9a22628777c06806990bb9a9972e8d773a7a92f5..0a3060fdc71b22cd655634c7b5d01b00363dffee 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -1,20 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" template using S = ck::Sequence; diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 32b183a3a160e5ffd05bfda859a8bfaea01bdfd5..d9677da9b9fd6aa2578cb20b3176e5c5d45b0ffd 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -1,20 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" template using S = ck::Sequence; diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index 16c9213104a8f99572dac365c622862ccd10a57f..65206d602f66eb800c783bace5a784fadee0c86a 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -1,20 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" template using S = ck::Sequence; diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index b126736be656e4c0f90136cd0badc65ae5c491de..0575c0bd9e2fa89a5f8823d7a7796d3d75a50ffd 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,20 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" template using S = ck::Sequence; @@ -83,8 +84,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; int main(int argc, char* argv[]) { @@ -215,24 +221,17 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_f32_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - bf16_to_f32_(a_m_k, a_f32_m_k); - bf16_to_f32_(b_k_n, b_f32_k_n); - bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument( - a_f32_m_k, b_f32_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); ref_invoker.Run(ref_argument); - return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1; + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; } return 0; diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 003534f79aa536f8f7aa374baab12c9a2668c06f..0d194403773b1564ba179d924b267b7e91d0e4e9 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,20 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include -#include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" template using S = ck::Sequence; @@ -27,30 +29,42 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; +// clang-format on + +using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle + // clang-format off +//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on +using DeviceGemmInstance = DeviceGemmInstance0; + using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; @@ -69,7 +83,11 @@ int main(int argc, char* argv[]) ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; - if(argc == 4) + if(argc == 1) + { + // use default case + } + else if(argc == 4) { do_verification = std::stoi(argv[1]); init_method = 
std::stoi(argv[2]); @@ -93,7 +111,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); exit(0); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 7cea68c8b0f11858b697c3cacb38f473632f0c61..1b222c971267102dbd3cbd7465aaf82009d6ecd9 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -1,21 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" template using S = ck::Sequence; diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index 27fcd62a2c13b1031ee1fdccd5fafe423ae8227e..4ed1f177db6d0e5df668256f232d631ca9f2464a 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,20 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
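The int8 GEMM example that follows mirrors the fp16 one; the interesting difference is the data-type mapping, where 8-bit inputs are reduced over K into a wider accumulator so the partial sums do not overflow (int32 is the usual choice for int8 GEMM and is assumed in this sketch; the example's own AccDataType is set further down in the file). A tiny standalone illustration of why the accumulator must be wider than the inputs:

```cpp
// Minimal sketch (not CK code): accumulating int8 products in int8 wraps
// around quickly, so a 32-bit accumulator is used for the K-reduction.
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    const int K = 1024;
    std::vector<int8_t> a(K, 5), b(K, 5); // each product is 25

    int8_t  acc8  = 0; // 25 * 1024 = 25600 does not fit in int8
    int32_t acc32 = 0;
    for(int k = 0; k < K; ++k)
    {
        acc8  = static_cast<int8_t>(acc8 + a[k] * b[k]); // wraps, result is meaningless
        acc32 += a[k] * b[k];                            // correct partial sum
    }
    std::printf("int8 acc = %d, int32 acc = %d\n", acc8, acc32);
    return 0;
}
```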
+ #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" template using S = ck::Sequence; diff --git a/example/02_gemm_alpha_beta/CMakeLists.txt b/example/02_gemm_alpha_beta/CMakeLists.txt deleted file mode 100644 index 1b81cf21622b6e70cb43dbd4bc90874fc7bf5580..0000000000000000000000000000000000000000 --- a/example/02_gemm_alpha_beta/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_example_executable(example_gemm_xdl_alpha_beta gemm_xdl_alpha_beta.cpp) diff --git a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp deleted file mode 100644 index 1a6e1de4dcfb4f75afca02b204e3963dab86b9e7..0000000000000000000000000000000000000000 --- a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp +++ /dev/null @@ -1,253 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm_bias_2d.hpp" - -template -using S = ck::Sequence; - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::AlphaBetaAdd; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_2d< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // 
BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; - - float alpha = 1.0f; - float beta = 1.0f; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - alpha = std::stof(argv[4]); - beta = std::stof(argv[5]); - } - else if(argc == 12) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); - - alpha = std::stof(argv[10]); - beta = std::stof(argv[11]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - c0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem 
c0_m_n_device_buf(sizeof(CDataType) * c0_m_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c0_m_n_device_buf.ToDevice(c0_m_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - - // do GEMM - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c0_m_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - AElementOp{}, - BElementOp{}, - CElementOp{alpha, beta}); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(do_verification) - { - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k, - b_k_n, - c0_m_n, - c_m_n_host_result, - AElementOp{}, - BElementOp{}, - CElementOp{alpha, beta}); - - ref_invoker.Run(ref_argument); - - return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 
0 : 1; - } - - return 0; -} diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..10ec0f1a71151668e262efcdbaff7100d2d08dfa --- /dev/null +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/example/02_gemm_alpha_beta/README.md b/example/02_gemm_bilinear/README.md similarity index 69% rename from example/02_gemm_alpha_beta/README.md rename to example/02_gemm_bilinear/README.md index ba2a3068f3e78757d34f3e9d7f382a76aef19bc5..9eb87e1e3479d72497ec72956b1de649b0ff735f 100644 --- a/example/02_gemm_alpha_beta/README.md +++ b/example/02_gemm_bilinear/README.md @@ -1,11 +1,13 @@ -# Instructions for ```example_gemm_xdl_alpha_beta``` +# Instructions for ```example_gemm_bilinear_xdl_fp16``` -## Run ```example_gemm_xdl_alpha_beta``` +## Run ```example_gemm_bilinear_xdl_fp16``` ```bash #arg1: verification (0=no, 1=yes) #arg2: initialization (0=no init, 1=integer value, 2=decimal value) -#arg3: run kernel # of times (>1) -./bin/example_gemm_xdl_alpha_beta 1 1 1 0.5 0.5 +#arg3: time kernel (0=no, 1=yes) +#arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE +#arg11 to 12: alpha, beta +./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5 ``` Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16) ``` diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b340807ba6f783b48ab860c6776799d14311649 --- /dev/null +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -0,0 +1,305 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
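This new example fuses a bilinear epilogue into the GEMM: the device op computes C = A·B in the accumulator type and then applies E = alpha·C + beta·D element-wise through the `AlphaBetaAdd` functor defined just below. A reduced host-side sketch of that epilogue, with plain float standing in for the half/float mix used in the example (names here are illustrative only):

```cpp
// Host-side sketch of the bilinear epilogue E = alpha * C + beta * D applied
// element-wise to the GEMM result. Types are simplified to float.
#include <cstddef>
#include <vector>

struct AlphaBetaAdd
{
    float alpha;
    float beta;
    void operator()(float& e, float c, float d) const { e = alpha * c + beta * d; }
};

void apply_epilogue(const std::vector<float>& c, // M x N GEMM result (accumulator type)
                    const std::vector<float>& d, // M x N bilinear input
                    std::vector<float>& e,       // M x N output
                    AlphaBetaAdd op)
{
    for(std::size_t i = 0; i < c.size(); ++i)
        op(e[i], c[i], d[i]);
}
// The verification loop at the end of this example does exactly this with
// cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)).
```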
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto 
layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index d07ad6e36c3a9f1deda141a66e20945c7fff37c1..35c54abac03094f24187df2503aa02b6812c20f3 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1 +1 @@ -add_example_executable(example_gemm_xdl_bias_relu gemm_xdl_bias_relu.cpp) +add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) diff --git a/example/03_gemm_bias_relu/README.md b/example/03_gemm_bias_relu/README.md index f8d9bd6152907de226567aefc85b91de00238e05..f28a9a071c879e92be34f84054661647c31ebb35 100644 --- a/example/03_gemm_bias_relu/README.md +++ b/example/03_gemm_bias_relu/README.md @@ -1,28 +1,10 @@ -# Instructions for ```example_gemm_xdl_bias_relu_add``` +# Instructions for ```example_gemm_bias_relu_xdl_fp16``` -## Run ```example_gemm_xdl_bias_relu_add``` +## Run ```example_gemm_bias_relu_xdl_fp16``` ```bash #arg1: verification (0=no, 1=yes) #arg2: initialization (0=no init, 1=integer value, 2=decimal value) -#arg3: run kernel # of times (>1) -#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC -./bin/example_gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096 -``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -arg.c0_grid_desc_m_n_{ 3840, 4096} -arg.c1_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... 
-Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s +#arg3: time kernel (0=no, 1=yes) +#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE +./bin/example_gemm_bias_relu_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 ``` diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e36280f42db89fe8d24d767365cc4fd40674af4a --- /dev/null +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// C = A * B +// E = Relu(C + D); +struct AddRelu +{ + __host__ __device__ void + operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const + { + const ck::half_t x = c + d; + + e = x > 0 ? x : 0; + } +}; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + 
exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! 
this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N + sizeof(EDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp deleted file mode 100644 index 3bf3003c147c7107aaa8cb2bda0eed7b1043ee5a..0000000000000000000000000000000000000000 --- a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp +++ /dev/null @@ -1,239 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" -#include "reference_gemm_bias_activation.hpp" - -template -using S = ck::Sequence; - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::AddRelu; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // 
BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActivation; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 10) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - // c0_n[n] - Tensor c0_n(HostTensorDescriptor( - std::vector({static_cast(N)}), std::vector({1}))); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "c0_n: " << c0_n.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - - DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - 
c0_n_device_buf.ToDevice(c0_n.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - - // do GEMM - auto gemm = DeviceGemmInstance{}; - - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - static_cast(c0_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + - sizeof(CDataType) * M * N + sizeof(CDataType) * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(do_verification) - { - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - - return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; - } - - return 0; -} diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..754de47c2b4556c62d7a04714224c29f23d0813f --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/README.md b/example/04_gemm_add_add_fastgelu/README.md new file mode 100644 index 0000000000000000000000000000000000000000..08a55fb9a37f9f7823ae8882d4d36e9bd96ee6fd --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/README.md @@ -0,0 +1,23 @@ +# Instructions for ```example_gemm_add_add_fastgelu_xdl_fp16``` + +## Run ```example_gemm_add_add_fastgelu_xdl_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" +./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} +d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... 
+Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8> +``` diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4bfbbbadf894d46aa1ef1a0333ba478a0b9a9f08 --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, 
EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD0 = std::stoi(argv[9]); + StrideD1 = std::stoi(argv[10]); + StrideE = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, " + "StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + 
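// Note on the D0 operand copied to the device just below: with the default
// StrideD0 = 0 its host descriptor has strides {0, 1}, i.e. a single row of N
// values broadcast across all M rows of the output (which is also why
// num_btype further down counts only sizeof(D0DataType) * N for it), while D1
// is a full M x N tensor with StrideD1.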
b_device_buf.ToDevice(b_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(D0DataType) * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << device_op.GetTypeString() << std::endl; + + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 
0 : 1; + } + + return 0; +} diff --git a/example/04_gemm_bias_relu_add/CMakeLists.txt b/example/04_gemm_bias_relu_add/CMakeLists.txt deleted file mode 100644 index 4f48db94a889814578edcf34b65958544a3e5278..0000000000000000000000000000000000000000 --- a/example/04_gemm_bias_relu_add/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_example_executable(example_gemm_xdl_bias_relu_add gemm_xdl_bias_relu_add.cpp) diff --git a/example/04_gemm_bias_relu_add/README.md b/example/04_gemm_bias_relu_add/README.md deleted file mode 100644 index f8d9bd6152907de226567aefc85b91de00238e05..0000000000000000000000000000000000000000 --- a/example/04_gemm_bias_relu_add/README.md +++ /dev/null @@ -1,28 +0,0 @@ -# Instructions for ```example_gemm_xdl_bias_relu_add``` - -## Run ```example_gemm_xdl_bias_relu_add``` -```bash -#arg1: verification (0=no, 1=yes) -#arg2: initialization (0=no init, 1=integer value, 2=decimal value) -#arg3: run kernel # of times (>1) -#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC -./bin/example_gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096 -``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -arg.c0_grid_desc_m_n_{ 3840, 4096} -arg.c1_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s -``` diff --git a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp deleted file mode 100644 index 73e92f9d116e023ee4dbaf95279ca4876403001c..0000000000000000000000000000000000000000 --- a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp +++ /dev/null @@ -1,257 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" -#include "reference_gemm_bias_activation_add.hpp" - -template -using S = ck::Sequence; - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::AddReluAdd; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 4, // 
K0PerBlock - 8, // K1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl -// clang-format on - -using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmBiasActivationAdd; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; - ck::index_t StrideC1 = 4096; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 11) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); - StrideC1 = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - // c0_n[n] - Tensor c0_n(HostTensorDescriptor( - std::vector({static_cast(N)}), std::vector({1}))); - - // c1_m_n[m ,n] - Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "c0_n: " << c0_n.mDesc << std::endl; - std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - 
c1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - c1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - - DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); - DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - c0_n_device_buf.ToDevice(c0_n.mData.data()); - c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - - // do GEMM - auto gemm = DeviceGemmInstance{}; - - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - static_cast(c0_n_device_buf.GetDeviceBuffer()), - static_cast(c1_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - StrideC1, - a_element_op, - b_element_op, - c_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + - sizeof(CDataType) * M * N + sizeof(CDataType) * N + - sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(do_verification) - { - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k, - b_k_n, - c_m_n_host_result, - c0_n, - c1_m_n, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - - return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; - } - - return 0; -} diff --git a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp index d50afb6854c5e09b61432885ecd4470c04b853f5..b3c492fd23fba59bb5578fb47278c2e0883aa41a 100644 --- a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp +++ b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp @@ -1,21 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "conv_util.hpp" -#include "device.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "reference_conv_fwd_bias_activation.hpp" -#include "tensor_layout.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp" namespace { diff --git a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp index 1a234ea851937ff3eb210f131f28d8ce74d9e162..7950630adba51aaa793c530849263ce2b6920496 100644 --- a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp +++ b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -1,21 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "conv_util.hpp" -#include "device.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "reference_conv_fwd_bias_activation_add.hpp" -#include "tensor_layout.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp" namespace { diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index 2f048097a1cdd66029c3c87dd18f3f18064cc1c9..5866956105f0a2b7b030aa1c0d55b43fad9a5012 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -1,19 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "conv_util.hpp" -#include "device.hpp" -#include "device_tensor.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "reference_conv_fwd.hpp" -#include "tensor_layout.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" namespace { @@ -291,8 +294,8 @@ int main(int argc, char* argv[]) float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString() - << std::endl; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv->GetTypeString() << std::endl; if(do_verification) { diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index 7fa0f0d27530bcb2b4255a564f352dfdcd97c894..beb78c3e9b9401bda01fef150c3b6e998f29ba26 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -1,19 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "conv_util.hpp" -#include "device.hpp" -#include "device_tensor.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "reference_conv_fwd.hpp" -#include "tensor_layout.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" namespace { diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp index 52440e0d5f112550303079b9c5fe794099db8beb..cf1273fada99df9b7acc5280115ccc28ac559693 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -1,19 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "conv_util.hpp" -#include "device.hpp" -#include "device_tensor.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "reference_conv_fwd.hpp" -#include "tensor_layout.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" namespace { diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 9a1028f88b0c472f56e702766a7cfdba7a84ae3a..3ca4b117661da631f7831f7382119870cb7cc8d6 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -1,19 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "conv_util.hpp" -#include "device.hpp" -#include "device_tensor.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "reference_conv_fwd.hpp" -#include "tensor_layout.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" namespace { diff --git a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp index 2d25f5ac2f180f92553803684891f4a5df0a4afa..340bc657fa52712eaf733c72579c6aa6c09b8fe9 100644 --- a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp +++ b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp @@ -1,21 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "element_wise_operation.hpp" -#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" -#include "reference_conv_bwd_data.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; diff --git a/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp index 1578161116c02c6bb7c77b26b00c193392274d66..e47ae661520b3c6dd22e46c925fb7371a8631b1b 100644 --- a/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp +++ b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp @@ -1,21 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "element_wise_operation.hpp" -#include "device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" -#include "reference_conv_backward_weight.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; diff --git a/example/12_reduce/README.md b/example/12_reduce/README.md index a6442984e7c72a33880135bd7ae00abc846df326..826d2f6c333d6b2707d16ffa987abf77d0d7c05c 100644 --- a/example/12_reduce/README.md +++ b/example/12_reduce/README.md @@ -5,14 +5,14 @@ # -D : input 4-d tensor lengths # -v : verification (0=no, 1=yes) #arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) -#arg2: time kernel (0=no, 1=yes) +#arg2: time kernel (0=no, 1=yes) ./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1 ``` Result ``` ./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1 -launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} +launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} Warm up 1 time Start running 10 times... 
Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1> @@ -24,19 +24,18 @@ Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr ```bash #arg1: verification (0=no, 1=yes( #arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) -#arg3: time kernel (0=no, 1=yes) +#arg3: time kernel (0=no, 1=yes) ./bin/example_reduce_blockwise_two_call 1 2 1 - +``` Result ``` ./bin/example_reduce_blockwise_two_call 1 2 1 -launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1} +launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1} Warm up 1 time Start running 10 times... -launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1} +launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1} Warm up 1 time Start running 10 times... Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> ``` - diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index cc75bbad6044c35f64d8aed798d05612cc543a0d..0a93af535810ca7df5e658f2371cde812d1609ee 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -1,23 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" -#include "device_reduce_multiblock.hpp" -#include "host_common_util.hpp" -#include "host_reduction.hpp" - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/host_tensor/host_common_util.hpp" +#include "ck/library/host_tensor/host_reduction.hpp" using namespace ck; using namespace ck::tensor_operation::device; @@ -33,11 +33,11 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; constexpr bool PropagateNan = true; constexpr bool OutputIndex = false; -using ReduceOperation = typename reduce_binary_operator::opType; +using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator::AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; using DeviceReduceInstance = DeviceReduceMultiBlock::GetElementwiseOperator( + static_cast(reduce_total_length)); + if(args.do_verification) { ReductionHost hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); - hostReduce.Run( - alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); + hostReduce.Run(alpha, + 
in.mData.data(), + beta, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + acc_elementwise_op); }; std::vector i_inLengths; @@ -277,20 +289,19 @@ int main(int argc, char* argv[]) auto reduce = DeviceReduceInstance{}; - auto argument_ptr = reduce.MakeArgumentPointer( - i_inLengths, - i_inStrides, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - in_dev.GetDeviceBuffer(), - nullptr, - out_dev.GetDeviceBuffer(), - out_index_dev.GetDeviceBuffer(), - InElementwiseOperation{static_cast(reduce_total_length)}, - AccElementwiseOperation{static_cast(reduce_total_length)}); + auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths, + i_inStrides, + i_outLengths, + i_outStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + out_index_dev.GetDeviceBuffer(), + in_elementwise_op, + acc_elementwise_op); if(!reduce.IsSupportedArgument(argument_ptr.get())) { diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index f42fd08f1e10d1277d142e034268d29f30193dfd..727c5877c5ea1f34ba3ade4f097ec7d258594220 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include @@ -5,20 +8,17 @@ #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" -#include "device_reduce_multiblock.hpp" -#include "host_common_util.hpp" -#include "host_reduction.hpp" - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/host_tensor/host_common_util.hpp" +#include "ck/library/host_tensor/host_reduction.hpp" using namespace ck; using namespace ck::tensor_operation::device; @@ -31,13 +31,13 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; constexpr bool PropagateNan = true; constexpr bool OutputIndex = false; -using ReduceOperation = typename reduce_binary_operator::opType; +using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator::AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; -using PassThroughOp = tensor_operation::element_wise::UnaryIdentic; +using PassThroughOp = tensor_operation::element_wise::PassThrough; using DeviceReduceInstance_1 = DeviceReduceMultiBlock::GetElementwiseOperator( + static_cast(reduce_total_length)); + if(do_verify) { ReductionHost hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims); - hostReduce.Run(alpha, in_1.mData.data(), beta, out_ref.mData.data(), nullptr); + hostReduce.Run(alpha, + in_1.mData.data(), + beta, + out_ref.mData.data(), + 
nullptr, + in_elementwise_op, + acc_elementwise_op); }; std::vector i_inLengths_1; @@ -217,20 +230,19 @@ int main(int argc, char* argv[]) auto reduce_1 = DeviceReduceInstance_1{}; - auto argument_ptr_1 = reduce_1.MakeArgumentPointer( - i_inLengths_1, - i_inStrides_1, - i_inLengths_2, - i_inStrides_2, - reduceDims_1, - 1.0f, - 0.0f, - in_1_dev.GetDeviceBuffer(), - nullptr, - in_2_dev.GetDeviceBuffer(), - nullptr, - InElementwiseOperation{static_cast(reduce_total_length)}, - PassThroughOp{}); + auto argument_ptr_1 = reduce_1.MakeArgumentPointer(i_inLengths_1, + i_inStrides_1, + i_inLengths_2, + i_inStrides_2, + reduceDims_1, + 1.0f, + 0.0f, + in_1_dev.GetDeviceBuffer(), + nullptr, + in_2_dev.GetDeviceBuffer(), + nullptr, + in_elementwise_op, + PassThroughOp{}); if(!reduce_1.IsSupportedArgument(argument_ptr_1.get())) { @@ -243,20 +255,19 @@ int main(int argc, char* argv[]) auto reduce_2 = DeviceReduceInstance_2{}; - auto argument_ptr_2 = reduce_2.MakeArgumentPointer( - i_inLengths_2, - i_inStrides_2, - i_outLengths, - i_outStrides, - reduceDims_2, - alpha, - beta, - in_2_dev.GetDeviceBuffer(), - nullptr, - out_dev.GetDeviceBuffer(), - nullptr, - PassThroughOp{}, - AccElementwiseOperation{static_cast(reduce_total_length)}); + auto argument_ptr_2 = reduce_2.MakeArgumentPointer(i_inLengths_2, + i_inStrides_2, + i_outLengths, + i_outStrides, + reduceDims_2, + alpha, + beta, + in_2_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + nullptr, + PassThroughOp{}, + acc_elementwise_op); if(!reduce_2.IsSupportedArgument(argument_ptr_2.get())) { diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp index 4652ce118957df59bc91287f4a0bc3140ee1187a..ac1d0f3a414508c46608f48194425e75e9db4230 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_common.hpp +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -1,20 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once #include -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "reduction_functions_accumulate.hpp" +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" template & in, const std::array& in_left_pads, const std::array& /*in_right_pads*/) { - const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; + const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; + + using ReduceOperation = typename ck::reduce_binary_operator::opType; - using ReduceOperation = typename ck::reduce_binary_operator::opType; - using InElementwiseOperation = typename ck:: - reduce_unary_operator::InElementwiseOperation; - using AccElementwiseOperation = typename ck:: - reduce_unary_operator::AccElementwiseOperation; + auto elementwise_ops = + ck::reduce_unary_operator::GetElementwiseOperator(reduceLength); - const InElementwiseOperation in_elementwise_op(divider); - const AccElementwiseOperation acc_elementwise_op(divider); + auto in_elementwise_op = std::get<0>(elementwise_ops); + auto acc_elementwise_op = std::get<1>(elementwise_ops); if constexpr(!OutputIndex) { @@ -48,7 +48,7 @@ static void pool_host_verify(const Tensor& in, ck::detail::AccumulateWithNanCheck; auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { - auto accuVal = ReduceOperation::GetIdentityValue(); + auto accuVal = ReduceOperation::template GetIdentityValue(); for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) { @@ -86,7 +86,7 @@ static void pool_host_verify(const Tensor& in, AccDataType, IndexDataType>; auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { - auto accuVal = ReduceOperation::GetIdentityValue(); + auto accuVal = ReduceOperation::template GetIdentityValue(); IndexDataType accuIndex = 0; for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp index 74507fdfb361094e2a6cb63421527e1052c2f80a..659f3251dcf9f8d7661cf44169458937f0d15088 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -1,9 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include -#include "config.hpp" -#include "tensor_layout.hpp" -#include "reduction_enums.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" #include "pool2d_fwd_common.hpp" diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp index 7ca5b1aab7984b07303f3a7fbf3a8c63aca1259f..f47c7ff15147327603b9eeba815e44465313762c 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp @@ -1,9 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include -#include "config.hpp" -#include "tensor_layout.hpp" -#include "reduction_enums.hpp" +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "pool2d_fwd_common.hpp" diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp index a42df2b7f06c312a35a807b5bfcbc15466340e31..379be22ad14c25acca808d73bd2b1eff91d934c7 100644 --- a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp +++ b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp @@ -1,22 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" struct RequantReluRequant { diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 503c87e1381331035c77472c07463ccd6c5421a2..cdb01b180db0adb3193e86819a94c682f3618aee 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -1,22 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_grouped_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" template using S = ck::Sequence; diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp index 4469130502bdb34aace51837d7b15bd8846b0b94..d20c863c4b8e071f6c3f6ef7e6944ff320d900f1 100644 --- a/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp +++ b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp @@ -1,20 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" -#include "element_wise_reduce_operation.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" template using S = ck::Sequence; @@ -31,20 +33,19 @@ using BDataType = F16; using CDataType = F16; using GemmAccDataType = F32; using ReduceAccDataType = F32; -using DDataType = F64; -using DPtrsGlobal = ck::Tuple; +using ReduceDataType = F64; +using ReducePtrsGlobal = ck::Tuple; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using DsReduceOp = ck::Tuple>; -using DsElementOp = ck::Tuple< - ck::tensor_operation::element_wise::UnaryIdentic>; -using DGlobalMemOp = +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceOps = ck::Tuple; +using ReduceElementOps = 
ck::Tuple; +using ReduceGlobalMemOps = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmSpecialization = @@ -52,11 +53,11 @@ static constexpr auto GemmSpecialization = // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| -//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -template +template void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) { std::size_t gemm_flop = std::size_t(2) * M * N * K; std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(CDataType) * M * N + sizeof(DDataType) * M; + sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M; float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; @@ -147,17 +148,17 @@ int main(int argc, char* argv[]) Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d_m_host_result( + Tensor reduce_m_host_result( HostTensorDescriptor(std::vector({static_cast(M)}))); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d_m_device_result( + Tensor reduce_m_device_result( HostTensorDescriptor(std::vector({static_cast(M)}))); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; + std::cout << "reduce_m: " << reduce_m_host_result.mDesc << std::endl; switch(init_method) { @@ -175,35 +176,40 @@ int main(int argc, char* argv[]) DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce_device_buf(sizeof(ReduceDataType) * + reduce_m_device_result.mDesc.GetElementSpace()); a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - auto ds_element_op = DsElementOp{}; - auto p_ds_global = ck::make_tuple(static_cast(d_device_buf.GetDeviceBuffer())); + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto reduce_element_op = ReduceElementOps{}[ck::Number<0>{}]; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + std::array reduce_element_ops = {&reduce_element_op}; + std::array p_reduces = {reduce_device_buf.GetDeviceBuffer()}; // do GEMM auto gemm = DeviceGemmReduceInstance{}; auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - p_ds_global, + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + 
b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, M, N, K, StrideA, StrideB, StrideC, - a_element_op, - b_element_op, - c_element_op, - ds_element_op, - ds_element_op); + {}, + gemm_element_ops, + {}, + reduce_element_ops, + reduce_element_ops); if(!gemm.IsSupportedArgument(argument)) { @@ -214,7 +220,7 @@ int main(int argc, char* argv[]) // [CAUSION]: launch_and_time_kernel will not initialize D. // If we evaluate kernel multiple time but without initialize D. Verification will fail - d_device_buf.SetValue(ck::NumericLimits::Lowest()); + reduce_device_buf.SetValue(ck::NumericLimits::Lowest()); invoker.Run(argument, StreamConfig{nullptr, false}); bool pass = true; @@ -222,7 +228,7 @@ int main(int argc, char* argv[]) if(do_verification) { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - d_device_buf.FromDevice(d_m_device_result.mData.data()); + reduce_device_buf.FromDevice(reduce_m_device_result.mData.data()); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -232,23 +238,27 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}]; + auto reduce_op = ReduceOps{}[ck::Number<0>{}]; for(int m = 0; m < M; ++m) { - ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue(); + ReduceAccDataType reduce_acc = reduce_op.GetIdentityValue(); for(int n = 0; n < N; ++n) - d_reduce_op(d_acc, c_m_n_host_result(m, n)); + { + ReduceAccDataType curr_val = + ck::type_convert(c_m_n_host_result(m, n)); + reduce_op(reduce_acc, curr_val); + }; - d_m_host_result(m) = d_acc; + reduce_m_host_result(m) = reduce_acc; } pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c") && - ck::utils::check_err(d_m_device_result.mData, - d_m_host_result.mData, + ck::utils::check_err(reduce_m_device_result.mData, + reduce_m_host_result.mData, "Error: Incorrect results d", 1e-3, 1e-3); @@ -258,7 +268,7 @@ int main(int argc, char* argv[]) { float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); - DumpGemmLayerNormPerf( + DumpGemmLayerNormPerf( gemm_reduceMax_ave_time, M, N, K); } diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp index e73e61c5325d2cc3297407630155efa1bc1ba88f..ddfaa9d7522a89f70ece49b980d7af831b3baab0 100644 --- a/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp +++ b/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp @@ -1,21 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" -#include "reduction_operator.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/reduction_operator.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" template using S = ck::Sequence; @@ -31,30 +33,27 @@ using BDataType = F16; using CDataType = F16; using GemmAccDataType = F32; using ReduceAccDataType = F32; -using DDataType = F32; -using DPtrsGlobal = ck::Tuple; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using D0ReduceOp = ck::reduce::Add; -using D1ReduceOp = ck::reduce::Add; -using DxsReduceOp = ck::Tuple; - -using UnaryIdenticElementOp = - ck::tensor_operation::element_wise::UnaryIdentic; -using UnaryDivElementOp = - ck::tensor_operation::element_wise::UnaryIdentic; -using UnarySquareElementOp = - ck::tensor_operation::element_wise::UnarySquare; -using DxsInElementOp = ck::Tuple; -using DxsOutElementOp = ck::Tuple; - -using DGlobalMemOp = +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceOp0 = ck::reduce::Add; +using ReduceOp1 = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = ck::InMemoryDataOperationEnumSequence; @@ -63,11 +62,11 @@ static constexpr auto GemmSpecialization = // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| 
CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| -//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -template +template void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) { std::size_t gemm_flop = 
std::size_t(2) * M * N * K; std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(CDataType) * M * N + sizeof(DDataType) * M + - sizeof(DDataType) * M; + sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; @@ -159,22 +158,22 @@ int main(int argc, char* argv[]) Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d0_m_host_result( + Tensor reduce0_m_host_result( HostTensorDescriptor(std::vector({static_cast(M)}))); - Tensor d1_m_host_result( + Tensor reduce1_m_host_result( HostTensorDescriptor(std::vector({static_cast(M)}))); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d0_m_device_result( + Tensor reduce0_m_device_result( HostTensorDescriptor(std::vector({static_cast(M)}))); - Tensor d1_m_device_result( + Tensor reduce1_m_device_result( HostTensorDescriptor(std::vector({static_cast(M)}))); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; - std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; + std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl; + std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl; switch(init_method) { @@ -192,39 +191,48 @@ int main(int argc, char* argv[]) DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); - DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + reduce0_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + reduce1_m_device_result.mDesc.GetElementSpace()); a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), - static_cast(d1_device_buf.GetDeviceBuffer())); + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; - auto dxs_in_element_op = DxsInElementOp{}; - auto dxs_out_element_op = DxsOutElementOp{M, M}; + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; // do GEMM auto gemm = DeviceGemmReduceInstance{}; auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - dxs_global, + auto argument = 
gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, M, N, K, StrideA, StrideB, StrideC, - a_element_op, - b_element_op, - c_element_op, - dxs_in_element_op, - dxs_out_element_op); + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops); if(!gemm.IsSupportedArgument(argument)) { @@ -233,9 +241,9 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - // init DO, D1 to 0 - d0_device_buf.SetZero(); - d1_device_buf.SetZero(); + // init reducetion buffer to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result // will not be correct. need to set time_kernel = false for correctness test @@ -245,8 +253,8 @@ int main(int argc, char* argv[]) if(do_verification) { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - d0_device_buf.FromDevice(d0_m_device_result.mData.data()); - d1_device_buf.FromDevice(d1_m_device_result.mData.data()); + reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -256,42 +264,40 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - auto d0_reduce_op = D0ReduceOp{}; - auto d1_reduce_op = D1ReduceOp{}; + auto reduce0_op = ReduceOp0{}; + auto reduce1_op = ReduceOp1{}; for(int m = 0; m < M; ++m) { - float d0_acc = d0_reduce_op.GetIdentityValue(); - float d1_acc = d1_reduce_op.GetIdentityValue(); + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); for(int n = 0; n < N; ++n) { - float c_val = ck::type_convert(c_m_n_host_result(m, n)); - float d0_val = 0; - float d1_val = 0; - - dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); - dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); - d0_reduce_op(d0_acc, d0_val); - d1_reduce_op(d1_acc, d1_val); + auto c_val = ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType square_c_val; + square(square_c_val, c_val); + + reduce0_op(reduce0_acc, c_val); + reduce1_op(reduce1_acc, square_c_val); } - dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); - dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); - d0_m_host_result(m) = ck::type_convert(d0_acc); - d1_m_host_result(m) = ck::type_convert(d1_acc); + div(reduce0_acc, reduce0_acc); + div(reduce1_acc, reduce1_acc); + reduce0_m_host_result(m) = ck::type_convert(reduce0_acc); + reduce1_m_host_result(m) = ck::type_convert(reduce1_acc); } pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c") && - ck::utils::check_err(d0_m_device_result.mData, - d0_m_host_result.mData, + ck::utils::check_err(reduce0_m_device_result.mData, + reduce0_m_host_result.mData, "Error: Incorrect results d0", 1e-4, 1e-5) && - ck::utils::check_err(d1_m_device_result.mData, - d1_m_host_result.mData, + ck::utils::check_err(reduce1_m_device_result.mData, + reduce1_m_host_result.mData, "Error: Incorrect results d1", 1e-3, 1e-5); @@ -301,7 +307,7 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); - DumpGemmLayerNormPerf(ave_time, M, N, K); + DumpGemmLayerNormPerf(ave_time, M, N, K); } return pass ? 
0 : 1; diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp index 0383197358ad66f35fe520c26d8ef1a406d65073..5e3a87e2e431bcb6aa8f67ae74a9ec6dda0cd8ce 100644 --- a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -1,21 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "config.hpp" -#include "conv_util.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "element_wise_operation.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "reference_conv_bwd_data.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 685762fc13ae6d37b5dc534b70a3e5bba63b6670..53bf671514c0ff779638db64be7a15eca5b9bd39 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -1,20 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include -#include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "reference_batched_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" template using S = ck::Sequence; @@ -29,28 +31,26 @@ using ADataType = F16; using BDataType = F16; using CDataType = F16; using ReduceAccDataType = F32; -using DDataType = F32; -using DPtrsGlobal = ck::Tuple; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using D0ReduceOp = ck::reduce::Add; -using D1ReduceOp = ck::reduce::Add; -using DxsReduceOp = ck::Tuple; - -using UnaryIdenticElementOp = - ck::tensor_operation::element_wise::UnaryIdentic; -using UnarySquareElementOp = - ck::tensor_operation::element_wise::UnarySquare; -using DxsInElementOp = ck::Tuple; -using DxsOutElementOp = ck::Tuple; - -using DGlobalMemOp = +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceOp0 = ck::reduce::Add; +using ReduceOp1 = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = ck::InMemoryDataOperationEnumSequence; @@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; + < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: @@ -143,16 +143,16 @@ int main(int argc, char* argv[]) Tensor c_g_m_n_host_result( f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( + Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( {static_cast(BatchCount), static_cast(M)}))); - Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( + Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( {static_cast(BatchCount), static_cast(M)}))); Tensor c_g_m_n_device_result( f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( + Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( {static_cast(BatchCount), static_cast(M)}))); - Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( + Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( {static_cast(BatchCount), static_cast(M)}))); std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; @@ -177,38 +177,48 @@ int main(int argc, char* argv[]) DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace()); - DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + d0_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + d1_g_m_device_result.mDesc.GetElementSpace()); a_device_buf.ToDevice(a_g_m_k.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data()); - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), - static_cast(d1_device_buf.GetDeviceBuffer())); + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&passthrough, &passthrough}; + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; // do GEMM auto 
batched_gemm = DeviceBatchedGemmReduceInstance{}; auto invoker = batched_gemm.MakeInvoker(); - auto argument = - batched_gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - dxs_global, - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - DxsInElementOp{}, - DxsOutElementOp{}, - BatchCount); + auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops, + BatchCount); if(!batched_gemm.IsSupportedArgument(argument)) { @@ -218,8 +228,8 @@ int main(int argc, char* argv[]) } // init DO, D1 to 0 - d0_device_buf.SetZero(); - d1_device_buf.SetZero(); + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result // will not be correct. need to set time_kernel = false for correctness test @@ -241,8 +251,8 @@ int main(int argc, char* argv[]) if(do_verification) { c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); - d0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); - d1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; auto ref_invoker = ref_batched_gemm.MakeInvoker(); @@ -252,30 +262,31 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - auto d0_reduce_op = D0ReduceOp{}; - auto d1_reduce_op = D1ReduceOp{}; + auto reduce0_op = ReduceOp0{}; + auto reduce1_op = ReduceOp1{}; for(int batch = 0; batch < BatchCount; ++batch) { for(int m = 0; m < M; ++m) { - float d0_acc = d0_reduce_op.GetIdentityValue(); - float d1_acc = d1_reduce_op.GetIdentityValue(); + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); for(int n = 0; n < N; ++n) { - float c_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); - float d0_val = 0; - float d1_val = 0; + auto c_val = + ck::type_convert(c_g_m_n_host_result(batch, m, n)); + ReduceAccDataType d0_val; + ReduceAccDataType d1_val; UnaryIdenticElementOp{}(d0_val, c_val); UnarySquareElementOp{}(d1_val, c_val); - d0_reduce_op(d0_acc, d0_val); - d1_reduce_op(d1_acc, d1_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); } - d0_g_m_host_result(batch, m) = ck::type_convert(d0_acc); - d1_g_m_host_result(batch, m) = ck::type_convert(d1_acc); + d0_g_m_host_result(batch, m) = ck::type_convert(reduce0_acc); + d1_g_m_host_result(batch, m) = ck::type_convert(reduce1_acc); } } diff --git a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp index 54557b6e7e8af02024ef02f02f50752a104029d5..aecd84cb8dae115df61f24ae0f7597f3bf4c2c21 100644 --- a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp +++ b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -1,39 +1,17 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2022 Advanced Micro Devices, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "binary_element_wise_operation.hpp" -#include "device_binary_elementwise.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" using F16 = ck::half_t; using F32 = float; @@ -42,8 +20,7 @@ using ABDataType = F16; using CDataType = F16; using EltwiseComputeDataType = F32; -using Add = ck::tensor_operation::binary_element_wise:: - Add; +using Add = ck::tensor_operation::element_wise::Add; using DeviceElementwiseAddInstance = ck::tensor_operation::device::DeviceBinaryElementwise input = {a_m_n_device_buf.GetDeviceBuffer(), + b_n_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_n_device_buf.GetDeviceBuffer()}; + + std::vector a_strides = {Stride, 1}; + std::vector b_strides = {0, 1}; + std::vector c_strides = {Stride, 1}; + auto broadcastAdd = DeviceElementwiseAddInstance{}; - auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(), - b_n_device_buf.GetDeviceBuffer(), - c_m_n_device_buf.GetDeviceBuffer(), - {M, N}, - {Stride, 1}, - {0, 1}, // broadcast in first dimension - {Stride, 1}, - Add{}); + auto argument = broadcastAdd.MakeArgumentPointer( + input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{}); if(!broadcastAdd.IsSupportedArgument(argument.get())) { diff --git a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp index ba02e459399e521375fdeaf221b801b5eb29a06e..89def92d2626a6ac2a36b015d0628840067be4d4 100644 --- a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp +++ b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -1,14 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "binary_element_wise_operation.hpp" -#include "device_binary_elementwise.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" using F16 = ck::half_t; using F32 = float; @@ -17,8 +20,7 @@ using ABDataType = F16; using CDataType = F16; using EltwiseComputeDataType = F32; -using Add = ck::tensor_operation::binary_element_wise:: - Add; +using Add = ck::tensor_operation::element_wise::Add; using DeviceElementwiseAddInstance = ck::tensor_operation::device::DeviceBinaryElementwise input = {a_m_device_buf.GetDeviceBuffer(), + b_m_n_k_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_n_k_device_buf.GetDeviceBuffer()}; + + std::vector a_strides = {1, 0, 0}; + std::vector b_strides{b_m_n_k.mDesc.GetStrides().begin(), + b_m_n_k.mDesc.GetStrides().end()}; + std::vector c_strides{c_m_n_k.mDesc.GetStrides().begin(), + c_m_n_k.mDesc.GetStrides().end()}; + auto broadcastAdd = DeviceElementwiseAddInstance{}; - auto argument = broadcastAdd.MakeArgumentPointer( - a_m_device_buf.GetDeviceBuffer(), - b_m_n_k_device_buf.GetDeviceBuffer(), - c_m_n_k_device_buf.GetDeviceBuffer(), - std::vector{mnk.begin(), mnk.end()}, - {1, 0, 0}, // broadcast A on second and third dimension - std::vector{b_m_n_k.mDesc.GetStrides().begin(), - b_m_n_k.mDesc.GetStrides().end()}, - std::vector{c_m_n_k.mDesc.GetStrides().begin(), - c_m_n_k.mDesc.GetStrides().end()}, - Add{}); + auto argument = + broadcastAdd.MakeArgumentPointer(input, + output, + std::vector{mnk.begin(), mnk.end()}, + {a_strides, b_strides}, + {c_strides}, + Add{}); if(!broadcastAdd.IsSupportedArgument(argument.get())) { diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp index c9791b1cb61018df6ba7a5f9b1c60a00a74a8e4a..aab60146a334012296757a1b315cbb655b2574a0 100644 --- a/example/19_binary_elementwise/elementwise_add_1d.cpp +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -1,39 +1,16 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2022 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "binary_element_wise_operation.hpp" -#include "device_binary_elementwise.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" using F16 = ck::half_t; using F32 = float; @@ -42,8 +19,7 @@ using ABDataType = F16; using CDataType = F16; using EltwiseComputeDataType = F32; -using Add = ck::tensor_operation::binary_element_wise:: - Add; +using Add = ck::tensor_operation::element_wise::Add; using DeviceElementwiseAddInstance = ck::tensor_operation::device::DeviceBinaryElementwise input = {a_m_device_buf.GetDeviceBuffer(), + b_m_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_device_buf.GetDeviceBuffer()}; + + std::vector a_strides = {1}; + std::vector b_strides = {1}; + std::vector c_strides = {1}; + auto broadcastAdd = DeviceElementwiseAddInstance{}; - auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(), - b_m_device_buf.GetDeviceBuffer(), - c_m_device_buf.GetDeviceBuffer(), - {M}, - {1}, - {1}, - {1}, - Add{}); + auto argument = broadcastAdd.MakeArgumentPointer( + input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{}); if(!broadcastAdd.IsSupportedArgument(argument.get())) { diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp index 30d7c8066a1f7fdf6f73fd1c9e8bbfd8f7593df0..a4a703a71c33136b735a1bdb4677158d15654015 100644 --- a/example/19_binary_elementwise/elementwise_add_4d.cpp +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -1,39 +1,17 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "binary_element_wise_operation.hpp" -#include "device_binary_elementwise.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" using F16 = ck::half_t; using F32 = float; @@ -42,8 +20,7 @@ using ABDataType = F16; using CDataType = F16; using EltwiseComputeDataType = F32; -using Add = ck::tensor_operation::binary_element_wise:: - Add; +using Add = ck::tensor_operation::element_wise::Add; using DeviceElementwiseAddInstance = ck::tensor_operation::device::DeviceBinaryElementwise input = {a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer()}; + std::array output = {c_device_buf.GetDeviceBuffer()}; + + std::vector a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()}; + std::vector b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()}; + std::vector c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()}; + auto broadcastAdd = DeviceElementwiseAddInstance{}; - auto argument = broadcastAdd.MakeArgumentPointer( - a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - c_device_buf.GetDeviceBuffer(), - std::vector{nchw.begin(), nchw.end()}, - std::vector{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()}, - std::vector{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()}, - std::vector{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()}, - Add{}); + auto argument = + broadcastAdd.MakeArgumentPointer(input, + output, + std::vector{nchw.begin(), nchw.end()}, + {{a_strides}, b_strides}, + {c_strides}, + Add{}); if(!broadcastAdd.IsSupportedArgument(argument.get())) { diff --git a/example/20_convnd_bwd_weight_xdl/CMakeLists.txt b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt index 1a644d947940d1afb2a1b129c37b73ef392de466..66fdef625a788ad7e42fef5e09fbb59ed43f1c18 100644 --- a/example/20_convnd_bwd_weight_xdl/CMakeLists.txt +++ b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt @@ -1,2 +1,4 @@ add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp) -target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util) \ No newline at end of file +add_example_executable(example_convnd_bwd_weight_xdl_bf16_splitk convnd_bwd_weight_xdl_bf16_splitk.cpp) +target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util) +target_link_libraries(example_convnd_bwd_weight_xdl_bf16_splitk PRIVATE conv_util) \ No newline at end of file diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp index 
65725d3ae809b1a77ca6145c8e39da9dd15e02df..e6d64e59646a9caf57442444531df5736a16c97b 100644 --- a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp @@ -1,22 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "conv_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "element_wise_operation.hpp" -#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" -#include "reference_conv_backward_weight.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -297,52 +297,15 @@ int main(int argc, char* argv[]) split_k); // alloc work space - size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); - float ave_time = 0.f; - if(std::is_same::value && split_k > 1) - { - DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); - wei_work_space_device_buf.SetZero(); - argument = conv->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_work_space_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - params.N_, - params.K_, - params.C_, - params.input_spatial_lengths_, - params.filter_spatial_lengths_, - output_spatial_lengths, - params.conv_filter_strides_, - params.conv_filter_dilations_, - params.input_left_pads_, - params.input_right_pads_, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}, - split_k); - - if(!conv->IsSupportedArgument(argument.get())) - { - std::cout << "wrong! device_conv with the specified compilation parameters does " - "not support this Conv problem" - << std::endl; - return 1; - } - - ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - } - else + float ave_time = 0.f; + if(!conv->IsSupportedArgument(argument.get())) { - if(!conv->IsSupportedArgument(argument.get())) - { - std::cout << "wrong! device_conv with the specified compilation parameters does " - "not support this Conv problem" - << std::endl; - return 1; - } - ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::cout << "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; } + ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = ck::utils::conv::get_flops( params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..34377bab94238842a19abce18b87dc7bdd2ba179 --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp @@ -0,0 +1,427 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/device_unary_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" + +using InDataType = ck::bhalf_t; +using WeiDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using UnaryTypeConvert = ck::tensor_operation::element_wise::UnaryTypeConvert; + +using DeviceUnaryElementwiseTypeConvertInstance = ck::tensor_operation::device:: + DeviceUnaryElementwise; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +using DeviceConvBwdWeightBasePtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + +// clang-format off +template +using DeviceConvndBwdWeightInstance_bf16_splitk = ck::tensor_operation::device:: + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + AccDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // 
BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +template +using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + +template +void host_elementwise(HostTensorB& B, + const HostTensorA& A, + const std::vector& shape, + Functor functor) +{ + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); + std::cout << __LINE__ << ":" << tensor_size << ", " << A.mData[0] << std::endl; + for(std::size_t n = 0; n < tensor_size; ++n) + { + B.mData[n] = functor(A.mData[n]); + } +} + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: is show log (0=no, 1=yes)\n" + << "arg5: split-k : in this example split-k must be larger than 1\n" + << "arg6: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + int arg_idx = 7; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + int do_log = 0; + int split_k = 2; + + 
ck::utils::conv::ConvParams params; + params.C_ = 128; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc > 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + num_dim_spatial = std::stoi(argv[6]); + // check args number + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 7; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + if(split_k <= 1) + { + print_use_msg(); + exit(1); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_host_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_device_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + // reset input to zero + wei_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + 
params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + // alloc work space + size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); + if(bwd_weight_workspace_size <= 0) + { + print_use_msg(); + exit(1); + } + + float conv_ave_time = 0.f; + + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer()); + + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / conv_ave_time; + + float gb_per_sec = num_btype / 1.E6 / conv_ave_time; + + std::cout << "Perf: conv: " << conv_ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, + wei_k_c_y_x_host_result.mData) + ? 
0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdWeightInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvBwdWeightInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvBwdWeightInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 3b854507bc5a088719b6e5ea172873d06a5dea9c..78d3a5d02a5be1a0c7e6475aa080a5c75047ac9b 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1 +1,3 @@ +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp) add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp) +add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp) diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c64cfcf0160c10c468afd3f51fe843d6c8e6f13 --- /dev/null +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp @@ -0,0 +1,432 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using BiasDataType = F32; +using D0DataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; +using NormalizeComputeDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = ck::tensor_operation::element_wise::Relu; +using D0ElementOp = PassThrough; +using ReduceSumOp = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using 
ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; + +// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y +using DeviceNormalizeInstance = + ck::tensor_operation::device::Device5AryElementwise; // scalarPerVector: LayerNorm_out + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + +template +void host_gemm_layernorm(Tensor& out_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& bias_n, + const Tensor& c1_m_n, + const Tensor& gamma_n, + const Tensor& beta_n, + A_functor a_element_op, + B_functor b_element_op, + C_functor c_element_op, + C1_functor c1_element_op, + int M, + 
int N) +{ + + int StrideC = N; + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = UnaryDivElementOp{N}; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + // c = activation(c + bias) + c1_functor(c1) + for(int m = 0; m < M; ++m) + for(int n = 0; n < N; ++n) + { + AccDataType acc = ck::type_convert(c_m_n(m, n)) + + ck::type_convert(bias_n(n)); + + AccDataType c1 = ck::type_convert(c1_m_n(m, n)); + + c_element_op(acc, acc); + c1_element_op(c1, c1); + acc += c1; + c_m_n(m, n) = ck::type_convert(acc); + } + + // reduce_mean and reduce_square_mean + auto reduceSumOpInst = ReduceSumOp{}; + for(int m = 0; m < M; ++m) + { + auto mean_acc = reduceSumOpInst.GetIdentityValue(); + auto square_mean_acc = reduceSumOpInst.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + AccDataType c_val = ck::type_convert(c_m_n(m, n)); + AccDataType square_c_val = 0; + UnarySquareElementOp{}(square_c_val, c_val); + + reduceSumOpInst(mean_acc, c_val); + reduceSumOpInst(square_mean_acc, square_c_val); + } + + averageOpInst(mean_acc, mean_acc); + averageOpInst(square_mean_acc, square_mean_acc); + mean_m(m) = ck::type_convert(mean_acc); + meanSquare_m(m) = ck::type_convert(square_mean_acc); + } + + // LayerNorm + auto layerNormInst = NormalizeFunctor{}; + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + AccDataType out_acc = 0; + layerNormInst(out_acc, + ck::type_convert(c_m_n(m, n)), + ck::type_convert(mean_m(m)), + ck::type_convert(meanSquare_m(m)), + ck::type_convert(gamma_n(n)), + ck::type_convert(beta_n(n))); + out_m_n(m, n) = ck::type_convert(out_acc); + } + } +} + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N + + sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; + + std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time; + + std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; + + std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec + << " GB/s, " << std::endl; +} + +int main() +{ + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideC = 1024; + ck::index_t StrideD0 = 1024; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor 
bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduceMean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + bias_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + c1_m_n.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace()); + DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace()); + DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * + reduceMeanSquare_m.mDesc.GetElementSpace()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); + DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * + layerNorm_m_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + d0_device_buf.ToDevice(c1_m_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto d_element_op = D0ElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; + + std::array p_reduces = {reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer()}; + + // Prepare GEMM, reduce_mean, reduce_mean_square + auto gemmReduce = DeviceGemmBiasAddReduceInstance{}; + auto gemmReduce_invoker = gemmReduce.MakeInvoker(); + auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer()}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {StrideD0}, + gemm_element_ops, + {&d_element_op}, + reduce_in_element_ops, + reduce_out_element_ops); + + if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + reduceMean_device_buf.SetZero(); + reduceMeanSquare_device_buf.SetZero(); + + // Prepare LayerNorm + std::array input = {c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer()}; + std::array output = {layerNorm_device_buf.GetDeviceBuffer()}; + + auto normalize = DeviceNormalizeInstance{}; + auto normalize_invoker = normalize.MakeInvoker(); + auto normalize_argument = normalize.MakeArgument(input, + output, + {M, N}, + {StrideC, 1}, + {1, 0}, + {1, 0}, + {0, 1}, + {0, 1}, + {StrideC, 1}, + NormalizeFunctor{}); + + if(!normalize.IsSupportedArgument(normalize_argument)) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "Device5AryElementwise instance, exiting!"); + } + + // run kernel + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false}); + + bool pass = true; + { + // verification + Tensor host_layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + host_gemm_layernorm(host_layerNorm_m_n, + a_m_k, + b_k_n, + bias_n, + c1_m_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + c_element_op, + d_element_op, + M, + N); + + layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); + pass &= ck::utils::check_err(layerNorm_m_n.mData, + host_layerNorm_m_n.mData, + "Error: Incorrect results layerNorm_m_n", + 1e-2, + 1e-2); + } + + { + // evaluate kernel perf + bool time_kernel = true; + + float gemm_reduce_mean_reduce_square_mean_ave_time = + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + float normalize_ave_time = + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp index 630f8df1f81fb42ab8d20833440113fdbfc22669..e418eea1a96f71126615370f834d3f762443e2db 100644 --- a/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp @@ -1,21 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include #include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_5ary_elementwise.hpp" -#include "device_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" -#include "element_wise_reduce_operation.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" template using S = ck::Sequence; @@ -31,8 +33,8 @@ using BDataType = F16; using CDataType = F16; using GemmAccDataType = F32; using ReduceAccDataType = F32; -using DDataType = F32; -using DPtrsGlobal = ck::Tuple; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; using GammaDataType = F16; using BetaDataType = F16; using LayerNormOutDataType = F16; @@ -45,19 +47,16 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using ReduceSumOp = ck::reduce::Add; -using DxsReduceOp = ck::Tuple; - -using UnaryIdenticElementOp = - ck::tensor_operation::element_wise::UnaryIdentic; -using UnaryDivElementOp = - ck::tensor_operation::element_wise::UnaryIdentic; -using UnarySquareElementOp = - ck::tensor_operation::element_wise::UnarySquare; -using DxsInElementOp = ck::Tuple; -using DxsOutElementOp = ck::Tuple; - -using DxsGlobalMemOp = +using ReduceSumOp = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = ck::InMemoryDataOperationEnumSequence; @@ -66,11 +65,11 @@ static constexpr auto GemmSpecialization = // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| -//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | 
XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm @@ -141,9 +140,9 @@ void host_gemm_layernorm(Tensor& out_m_n, int StrideC = N; Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); - Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); - Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); - auto averageOpInst = UnaryDivElementOp{M}; + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor 
meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = UnaryDivElementOp{N}; auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -157,13 +156,14 @@ void host_gemm_layernorm(Tensor& out_m_n, auto reduceSumOpInst = ReduceSumOp{}; for(int m = 0; m < M; ++m) { - float mean_acc = reduceSumOpInst.GetIdentityValue(); - float square_mean_acc = reduceSumOpInst.GetIdentityValue(); + auto mean_acc = reduceSumOpInst.GetIdentityValue(); + auto square_mean_acc = reduceSumOpInst.GetIdentityValue(); for(int n = 0; n < N; ++n) { - ReduceAccDataType c_val = ck::type_convert(c_m_n(m, n)); - ReduceAccDataType square_c_val = 0; + auto c_val = ck::type_convert(c_m_n(m, n)); + auto square_c_val = reduceSumOpInst.GetIdentityValue(); + UnarySquareElementOp{}(square_c_val, c_val); reduceSumOpInst(mean_acc, c_val); @@ -172,8 +172,8 @@ void host_gemm_layernorm(Tensor& out_m_n, averageOpInst(mean_acc, mean_acc); averageOpInst(square_mean_acc, square_mean_acc); - mean_m(m) = ck::type_convert(mean_acc); - meanSquare_m(m) = ck::type_convert(square_mean_acc); + mean_m(m) = ck::type_convert(mean_acc); + meanSquare_m(m) = ck::type_convert(square_mean_acc); } // LayerNorm @@ -183,7 +183,12 @@ void host_gemm_layernorm(Tensor& out_m_n, for(int n = 0; n < N; ++n) { float out_f32 = 0; - layerNormInst(out_f32, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n)); + layerNormInst(out_f32, + static_cast(c_m_n(m, n)), + static_cast(mean_m(m)), + static_cast(meanSquare_m(m)), + static_cast(gamma_n(n)), + static_cast(beta_n(n))); out_m_n(m, n) = static_cast(out_f32); } } @@ -192,7 +197,7 @@ void host_gemm_layernorm(Tensor& out_m_n, template @@ -200,11 +205,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, { std::size_t gemm_flop = std::size_t(2) * M * N * K; std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(CDataType) * M * N + sizeof(DDataType) * M + - sizeof(DDataType) * M; + sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; - std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(DDataType) * M + - sizeof(DDataType) * M + sizeof(GammaDataType) * N + + std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N + sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; @@ -232,8 +237,8 @@ int main() Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); - Tensor reduceMean_m(f_host_tensor_descriptor1d(M, 1)); - Tensor reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); + Tensor reduceMean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); Tensor layerNorm_m_n( @@ -247,8 +252,8 @@ int main() DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); - DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace()); - DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) * + DeviceMem 
reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace()); + DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * reduceMeanSquare_m.mDesc.GetElementSpace()); DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); @@ -260,35 +265,40 @@ int main() gamma_device_buf.ToDevice(gamma_n.mData.data()); beta_device_buf.ToDevice(beta_n.mData.data()); - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - auto dxs_global = - ck::make_tuple(static_cast(reduceMean_device_buf.GetDeviceBuffer()), - static_cast(reduceMeanSquare_device_buf.GetDeviceBuffer())); + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; - auto dxs_in_element_op = DxsInElementOp{}; - auto dxs_out_element_op = DxsOutElementOp{M, M}; + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; + + std::array p_reduces = {reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer()}; // Prepare GEMM, reduce_mean, reduce_mean_square - auto gemmReduce = DeviceGemmReduceInstance{}; - auto gemmReduce_invoker = gemmReduce.MakeInvoker(); - auto gemmReduce_argument = - gemmReduce.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - dxs_global, - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - dxs_in_element_op, - dxs_out_element_op); + auto gemmReduce = DeviceGemmReduceInstance{}; + auto gemmReduce_invoker = gemmReduce.MakeInvoker(); + auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops); if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) { @@ -301,23 +311,25 @@ int main() reduceMeanSquare_device_buf.SetZero(); // Prepare LayerNorm + std::array input = {c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer()}; + std::array output = {layerNorm_device_buf.GetDeviceBuffer()}; + auto normalize = DeviceNormalizeInstance{}; auto normalize_invoker = normalize.MakeInvoker(); - auto normalize_argument = normalize.MakeArgument( - static_cast(c_device_buf.GetDeviceBuffer()), - static_cast(reduceMean_device_buf.GetDeviceBuffer()), - static_cast(reduceMeanSquare_device_buf.GetDeviceBuffer()), - static_cast(gamma_device_buf.GetDeviceBuffer()), - static_cast(beta_device_buf.GetDeviceBuffer()), - static_cast(layerNorm_device_buf.GetDeviceBuffer()), - {M, N}, - {StrideC, 1}, - {1, 0}, - {1, 0}, - {0, 1}, - {0, 1}, - {StrideC, 1}, - NormalizeFunctor{}); + auto normalize_argument = normalize.MakeArgument(input, + output, + {M, N}, + {StrideC, 1}, + {1, 0}, + {1, 0}, + {0, 1}, + {0, 1}, + {StrideC, 1}, + NormalizeFunctor{}); if(!normalize.IsSupportedArgument(normalize_argument)) { @@ 
-335,16 +347,16 @@ int main() Tensor host_layerNorm_m_n( f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); - host_gemm_layernorm(host_layerNorm_m_n, - a_m_k, - b_k_n, - gamma_n, - beta_n, - a_element_op, - b_element_op, - c_element_op, - M, - N); + host_gemm_layernorm(host_layerNorm_m_n, + a_m_k, + b_k_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + c_element_op, + M, + N); layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); pass &= ck::utils::check_err(layerNorm_m_n.mData, @@ -367,7 +379,7 @@ int main() DumpGemmLayerNormPerf( diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06506cab8e8505235706afeac51fbf220d2b68de --- /dev/null +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// This example demonstrate a single kernel that runs GEMM layer and laynorm in one fused kernel +// +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +// +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using C0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +struct Relu +{ + template + __host__ __device__ void operator()(OutT& y, const InT& x) const + { + y = x > 0 ? 
x : 0; + } +}; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +// Elementwise operation that operates on the output of matrix multiplication +// i.e., AccElementOp(A * B + bias) +using AccElementOp = Relu; +// Elementwise operation that operates on the output of layer normalization +using CElementOp = Relu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle +//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| Acc| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadCopy| +//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4>; +// clang-format on + +using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 128; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 128; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); 
+ } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c0_n_bias(HostTensorDescriptor(std::vector({size_t(N)}))); + Tensor c0_m_n_add(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c0_n_gamma(HostTensorDescriptor(std::vector({size_t(N)}))); + Tensor c0_n_beta(HostTensorDescriptor(std::vector({size_t(N)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n_bias: " << c0_n_bias.mDesc << std::endl; + std::cout << "c0_m_n_add: " << c0_m_n_add.mDesc << std::endl; + std::cout << "c0_n_gamma: " << c0_n_gamma.mDesc << std::endl; + std::cout << "c0_n_beta: " << c0_n_beta.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + c0_n_bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_m_n_add.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n_gamma.GenerateTensorValue(GeneratorTensor_2{0, 2}); + c0_n_beta.GenerateTensorValue(GeneratorTensor_2{0, 5}); + c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1{0}); + acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1{0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpace()); + DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpace()); + DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpace()); + DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c0_bias_buf.ToDevice(c0_n_bias.mData.data()); + c0_add_buf.ToDevice(c0_m_n_add.mData.data()); + c0_gamma_buf.ToDevice(c0_n_gamma.mData.data()); + c0_beta_buf.ToDevice(c0_n_beta.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto acc_element_op = AccElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + 
static_cast(c0_add_buf.GetDeviceBuffer()), + static_cast(c0_bias_buf.GetDeviceBuffer()), + static_cast(c0_gamma_buf.GetDeviceBuffer()), + static_cast(c0_beta_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // extra 6MN flops due to: bias + add + gamma + beta + norm_sub + norm_div, + // excluding reduction steps + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(6) * M * N; + // extra MN and 3N due to c0_add (MxN), bias (1xN), gamma (1xN), beta (1xN) + std::size_t bytes = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * 2 * M * N + sizeof(C0DataType) * 3 * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + auto ref_gemm = ReferenceInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c_m_n_host_result, + c0_n_bias, + c0_m_n_add, + c0_n_gamma, + c0_n_beta, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + if constexpr(std::is_same::value) + { + pass &= ck::utils::check_err( + c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c"); + } + else if constexpr(std::is_same::value) + { + pass &= ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results c", + 1e-2, + 1e-2); + } + } + return pass ? 0 : 1; +} diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp index 9790164e7264b040c04230f93c2f3f58b0fdd352..a1dbf0b6c40ccec8b650c27422b403425f54f1d0 100644 --- a/example/22_cgemm/cgemm_xdl_fp16.cpp +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -1,45 +1,21 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2022 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_cgemm_4gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_cgemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp" template using S = ck::Sequence; diff --git a/example/23_softmax/CMakeLists.txt b/example/23_softmax/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..dafe65521aab12d55f3ead8ca1ae4f78f9810bde --- /dev/null +++ b/example/23_softmax/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_softmax_blockwise softmax_blockwise.cpp) \ No newline at end of file diff --git a/example/23_softmax/README.md b/example/23_softmax/README.md new file mode 100644 index 0000000000000000000000000000000000000000..37c43e9b552245cb05a50750ed93f8e730353ff3 --- /dev/null +++ b/example/23_softmax/README.md @@ -0,0 +1,18 @@ +# Instructions for ```example_softmax_blockwise``` + +## Run ```example_softmax_blockwise``` +```bash +# -D : input 3-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg2: time kernel (0=no, 1=yes) +example_softmax_blockwise -D 4,128,2048 -v 1 1 1 +``` + +Result +``` +launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.0242877 ms, 259.039 GB/s, DeviceReduceSoftmax<256,M_C8_S1,K_C32_S8,InSrcVectorDim_1_InSrcVectorSize_8_OutDstVectorSize_8> +``` diff --git a/example/23_softmax/softmax_blockwise.cpp b/example/23_softmax/softmax_blockwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6df3155e8094865a465099171d9649742632d713 --- /dev/null +++ b/example/23_softmax/softmax_blockwise.cpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
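+
+// Editorial note (host-side sketch, not part of the original patch): DeviceSoftmax here reduces
+// along the last dimension of the 3-d input (reduceDims = {2}). Per slice x[0..K) of that
+// dimension the reference result is the usual numerically stable softmax, with alpha/beta acting
+// as BLAS-like scaling of the new result against any prior output:
+//   m      = max_k x[k]
+//   s      = sum_k exp(x[k] - m)
+//   out[k] = alpha * exp(x[k] - m) / s + beta * out_prior[k]
+// (The alpha/beta semantics are an assumption read off how main() preloads out_dev when
+// beta != 0; ReferenceSoftmax below is the authoritative definition.)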
+ +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_common_util.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +constexpr int Rank = 3; +constexpr int NumReduceDim = 1; + +using DeviceInstance = DeviceSoftmax; // OutScalarPerVector + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {8, 128, 2048}; + std::vector scales = {2.0f, 2.0f}; + + bool do_verification = true; + int init_method = 2; + bool time_kernel = true; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + return (0); + }; +}; + +int main(int argc, char* argv[]) +{ + // Example: batched gemm C[G, M, N] applies max/sum reduction along N internally + const std::vector invariantDims{0, 1}; + const std::vector reduceDims{2}; + + SimpleAppArgs args; + + if(argc > 1) + { + if(args.processArgs(argc, argv) < 0) + return (-1); + }; + + Tensor in(args.inLengths); + Tensor out_ref(args.inLengths); + Tensor out(args.inLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + AccDataType alpha = args.scales[0]; + AccDataType beta = args.scales[1]; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + std::size_t num_thread = 1; + + if(args.do_verification) + { + switch(args.init_method) 
+ { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + // std::cout << "beta = " << beta << std::endl; + // LogRangeAsType(std::cout << "tensor in: " , in.mData, ",") << std::endl; + // LogRangeAsType(std::cout << "tensor prior out: " , out.mData, ",") << std::endl; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + if(args.do_verification) + { + using ReferenceInstance = + tensor_operation::host::ReferenceSoftmax; + ReferenceInstance ref; + auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims); + auto invoker = ref.MakeInvoker(); + invoker.Run(ref_arg); + // LogRangeAsType(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl; + }; + + std::vector i_inLengths; + std::vector i_inStrides; + + i_inLengths.assign(args.inLengths.begin(), args.inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + + auto device_instance = DeviceInstance{}; + + auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths, + i_inStrides, + reduceDims, + &alpha, + &beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer()); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + return 1; + }; + + std::string instance_name = device_instance.GetTypeString(); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + + bool pass = true; + if(args.do_verification) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + out_dev.FromDevice(out.mData.data()); + // LogRangeAsType(std::cout << "tensor out: " , out.mData, ",") << std::endl; + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); + }; + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel}); + + std::size_t num_bytes = + in.mDesc.GetElementSize() * sizeof(InDataType) + + (beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name + << std::endl; + + return (pass ? 
0 : 1); +} diff --git a/example/24_batched_gemm_c_permute/CMakeLists.txt b/example/24_batched_gemm_c_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..79c612d05359dfd36f6d91e37153a440bbe94b7f --- /dev/null +++ b/example/24_batched_gemm_c_permute/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_batched_gemm_c_permute_xdl_fp16 batched_gemm_c_permute_xdl_fp16.cpp) + diff --git a/example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp b/example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..81a1f7d1d70284988044d8820f913847d9eec142 --- /dev/null +++ b/example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp @@ -0,0 +1,245 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl +//######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; + < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: + ReferenceBatchedGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + const int M = 88; + const int N = 64; + const int K = 88; + + const int stride_A = K; + const int stride_B = K; + + const int G0 = 1024; + const int G1 = 10; + + const int batch_count = G0 * G1; + + // output layout - [G0, M, G1, N] + const int stride_G0 = M * G1 * N; + const int stride_G1 = N; + const int stride_M = G1 * N; + const int stride_N = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + // GEMM shape + ck::tensor_operation::device::BatchedGemmCPermuteDesc batched_gemm_c_permute_desc{ + G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N}; + + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count_, row, col}), + std::vector({row * stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count_, row, col}), + std::vector({col * stride, 1, stride})); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{})); + + auto f_host_c_tensor_descriptor = [](std::size_t G0_, + std::size_t G1_, + std::size_t M_, + std::size_t N_, + std::size_t stride_G0_, + std::size_t stride_G1_, + std::size_t stride_M_, + std::size_t stride_N_) { + return HostTensorDescriptor( + std::vector({G0_, G1_, M_, N_}), + std::vector({stride_G0_, stride_G1_, stride_M_, stride_N_})); + }; + + Tensor c_g0_g1_m_n_host_result( + f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N)); + + Tensor c_g0_g1_m_n_device_result( + f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N)); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g0_g1_m_n: " << c_g0_g1_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + 
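+ // The remainder of main() allocates device buffers and launches the batched GEMM. Its result is
+ // logically [batch_count, M, N], but the kernel writes it directly into the permuted
+ // [G0, M, G1, N] layout described by batched_gemm_c_permute_desc, with the batch index
+ // decomposed as g = g0 * G1 + g1; the host verification loop below reproduces the same mapping.
+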
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + stride_A, + stride_B, + batched_gemm_c_permute_desc, + a_element_op, + b_element_op, + c_element_op, + batch_count); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * batch_count * M * N * K; + std::size_t num_btype = sizeof(ADataType) * batch_count * M * K + + sizeof(BDataType) * batch_count * K * N + + sizeof(CDataType) * batch_count * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + c_device_buf.FromDevice(c_g0_g1_m_n_device_result.mData.data()); + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + Tensor c_g_m_n_host_result = HostTensorDescriptor( + std::vector({batch_count, M, N}), std::vector({M * N, N, 1})); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int g0 = 0; g0 < G0; g0++) + { + for(int g1 = 0; g1 < G1; g1++) + { + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + int g = g0 * G1 + g1; + c_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n); + } + } + } + } + + pass = ck::utils::check_err(c_g0_g1_m_n_host_result.mData, + c_g0_g1_m_n_device_result.mData, + "Error: Incorrect results c"); + } + + return pass ? 0 : 1; +} diff --git a/example/25_gemm_bias_c_permute/CMakeLists.txt b/example/25_gemm_bias_c_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..29b1d94b3c7851af70c8c990d03c72ca540d6c88 --- /dev/null +++ b/example/25_gemm_bias_c_permute/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_bias_c_permute_xdl_fp16 gemm_bias_c_permute_xdl_fp16.cpp) diff --git a/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp b/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7a439ca34f5e15e75e9df045fc4148481b659bd --- /dev/null +++ b/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp @@ -0,0 +1,284 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
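+
+// Editorial note (a sketch of what this example computes, not part of the original patch):
+//   E = A * B + D
+// where D is a bias broadcast over the M dimensions (its M0/M1/M2 strides are set to 0 below)
+// and the M x N GEMM result is written out through a permuted 5-d view with
+//   m = m0 * M1 * M2 + m1 * M2 + m2,   n = n0 * N1 + n1
+// CDEElementOp is Add, so the epilogue is a plain elementwise add of the broadcast bias; the
+// verification loop at the bottom of the file applies exactly this index mapping on the host.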
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xdl +//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t M0 = 4; + ck::index_t M1 = 32; + ck::index_t M2 = 128; + ck::index_t N0 = 16; + ck::index_t N1 = 256; + + // GEMM shape + ck::index_t M = M0 * M1 * M2; + ck::index_t N = N0 * N1; + ck::index_t K = 128; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + +#if 1 + // E = [M0, N0, M1, N1, M2] + ck::index_t stride_E_M0 = N0 
* M1 * N1 * M2; + ck::index_t stride_E_M1 = N1 * M2; + ck::index_t stride_E_M2 = 1; + ck::index_t stride_E_N0 = M1 * N1 * M2; + ck::index_t stride_E_N1 = M2; + + // D = [0, N0, 0, N1, 0] + ck::index_t stride_D_M0 = 0; + ck::index_t stride_D_M1 = 0; + ck::index_t stride_D_M2 = 0; + ck::index_t stride_D_N0 = N1; + ck::index_t stride_D_N1 = 1; +#else + // D = [0, 0, 0, N0, N1] + ck::index_t stride_D_M0 = 0; + ck::index_t stride_D_M1 = 0; + ck::index_t stride_D_M2 = 0; + ck::index_t stride_D_N0 = N1; + ck::index_t stride_D_N1 = 1; + + // E = [M0, M1, M2, N0, N1] + ck::index_t stride_E_M0 = M1 * M2 * N0 * N1; + ck::index_t stride_E_M1 = M2 * N0 * N1; + ck::index_t stride_E_M2 = N0 * N1; + ck::index_t stride_E_N0 = N1; + ck::index_t stride_E_N1 = 1; +#endif + + const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc{ + M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1}; + const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc{ + M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + auto f_host_de_tensor_descriptor = + [](ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 de_grid_desc) { + std::size_t m0 = de_grid_desc.M0_; + std::size_t m1 = de_grid_desc.M1_; + std::size_t m2 = de_grid_desc.M2_; + std::size_t n0 = de_grid_desc.N0_; + std::size_t n1 = de_grid_desc.N1_; + std::size_t stride_m0 = de_grid_desc.stride_M0_; + std::size_t stride_m1 = de_grid_desc.stride_M1_; + std::size_t stride_m2 = de_grid_desc.stride_M2_; + std::size_t stride_n0 = de_grid_desc.stride_N0_; + std::size_t stride_n1 = de_grid_desc.stride_N1_; + return HostTensorDescriptor( + std::vector({m0, m1, m2, n0, n1}), + std::vector({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1})); + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + Tensor d_m0_m1_m2_n0_n1(f_host_de_tensor_descriptor(d_grid_desc)); + Tensor e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc)); + Tensor e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc)); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m0_m1_m2_n0_n1: " << d_m0_m1_m2_n0_n1.mDesc << std::endl; + std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + 
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) * + d_m0_m1_m2_n0_n1.mDesc.GetElementSpace()); + DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) * + e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + d_m0_m1_m2_n0_n1_device_buf.ToDevice(d_m0_m1_m2_n0_n1.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(), + b_k_n_device_buf.GetDeviceBuffer(), + d_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(), + e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(), + M, + N, + K, + stride_A, + stride_B, + d_grid_desc, + e_grid_desc, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << device_op.GetTypeString() << std::endl; + + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m0 = 0; m0 < M0; ++m0) + for(int m1 = 0; m1 < M1; ++m1) + for(int m2 = 0; m2 < M2; ++m2) + for(int n0 = 0; n0 < N0; ++n0) + for(int n1 = 0; n1 < N1; ++n1) + { + int m = m0 * M1 * M2 + m1 * M2 + m2; + int n = n0 * N1 + n1; + + cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1), + ck::type_convert(c_m_n(m, n)), + d_m0_m1_m2_n0_n1(m0, m1, m2, n0, n1)); + } + + e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data()); + + return ck::utils::check_err(e_m0_m1_m2_n0_n1_device_result.mData, + e_m0_m1_m2_n0_n1_host_result.mData) + ? 
0 + : 1; + } + + return 0; +} diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..87f4750e3bf5dd81c001718c388038a5840d8a88 --- /dev/null +++ b/example/26_contraction/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) +add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) diff --git a/example/26_contraction/README.md b/example/26_contraction/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c88d93cf83a411e2499cd0be80e524c42db339c6 --- /dev/null +++ b/example/26_contraction/README.md @@ -0,0 +1,20 @@ +# Instructions for ```example_contraction_bilinear_xdl_fp32``` + +## Run +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_contraction_bilinear_xdl_fp32 1 1 1 +``` + +Result (MI100 @ dynamic freq, 46TFlops peak FP32) +``` +a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} +b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096} +c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} +launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4> +``` diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed3f2c0e829e86b5c9c1179f2ca7045eb7a75b0d --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -0,0 +1,444 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
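The same binary also accepts a long-form invocation (handled by the `argc == 28` branch further down in `contraction_bilinear_xdl_fp32.cpp`) in which the tensor extents, the strides of A, B, D and E, and the `alpha`/`beta` scalars are passed explicitly. A hypothetical run that simply reproduces the default problem shape and strides shown above:

```bash
# arg1-3:  verification, init method, time kernel
# arg4-9:  M0 M1 N0 N1 K0 K1
# arg10-13 / 14-17 / 18-21 / 22-25: strides of A, B, D, E
# arg26-27: alpha, beta
./bin/example_contraction_bilinear_xdl_fp32 1 1 1 \
    30 128 32 64 32 64 \
    524288 4096 128 1 \
    524288 4096 128 1 \
    524288 4096 128 1 \
    524288 4096 128 1 \
    1.0 1.0
```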
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceKNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + for(int k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const 
ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks( + std::vector(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), + std::vector(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); + Tensor b_ns_ks( + std::vector(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()), + std::vector(b_ns_ks_strides.begin(), 
b_ns_ks_strides.end())); + Tensor d_ms_ns( + std::vector(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()), + std::vector(d_ms_ns_strides.begin(), d_ms_ns_strides.end())); + Tensor e_ms_ns_host_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + Tensor e_ms_ns_device_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), + e_ms_ns_lengths.begin() + NumDimM, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, + e_ms_ns_lengths.begin() + NumDimM + NumDimN, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, + a_ms_ks_lengths.begin() + NumDimM + NumDimK, + ck::index_t{1}, + std::multiplies{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor 
c_ms_ns_host_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + + using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5337c45a7c09276b21502efa8a39d7d54608954 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstanceKKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + 
CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + for(int k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), 
std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks( + std::vector(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), + std::vector(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); + Tensor b_ns_ks( + std::vector(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()), + std::vector(b_ns_ks_strides.begin(), b_ns_ks_strides.end())); + Tensor e_ms_ns_host_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + Tensor e_ms_ns_device_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), + e_ms_ns_lengths.begin() + NumDimM, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, + e_ms_ns_lengths.begin() + NumDimM + NumDimN, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, + a_ms_ks_lengths.begin() + NumDimM + NumDimK, + ck::index_t{1}, + std::multiplies{}); + + std::size_t flop = std::size_t(2) * M * N * K;
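    // Back-of-the-envelope check of the reported numbers, assuming the default
    // problem shape above (M0=30, M1=128, N0=32, N1=64, K0=32, K1=64):
    //   M = 30 * 128 = 3840, N = 32 * 64 = 2048, K = 32 * 64 = 2048
    //   flop = 2 * M * N * K ~ 3.22e10 FLOP
    // With tflops = flop / 1e9 / ave_time(ms), a run time of ~0.84 ms works out
    // to ~38 TFLOPS, which matches the figure quoted in the README for the
    // bilinear variant of this example.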
+ std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + + using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 158b677cab07defd8a2b48c12cc30c24e93e799d..f9bf1b2ee6b7227c98f627676bba39a5a1e6811a 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,26 +1,6 @@ include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/include/ck - ${PROJECT_SOURCE_DIR}/include/ck/utility - ${PROJECT_SOURCE_DIR}/include/ck/host_utility - ${PROJECT_SOURCE_DIR}/include/ck/tensor_description - ${PROJECT_SOURCE_DIR}/include/ck/tensor - ${PROJECT_SOURCE_DIR}/include/ck/problem_transform - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/device - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/grid - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/block - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/thread - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/element - ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor - ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu - ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu - ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility - ${PROJECT_SOURCE_DIR}/external/include/half + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include ) add_custom_target(examples) @@ -42,9 +22,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endfunction(add_example_executable_no_testing EXAMPLE_NAME) add_subdirectory(01_gemm) -add_subdirectory(02_gemm_alpha_beta) +add_subdirectory(02_gemm_bilinear) add_subdirectory(03_gemm_bias_relu) -add_subdirectory(04_gemm_bias_relu_add) +add_subdirectory(04_gemm_add_add_fastgelu) add_subdirectory(06_conv2d_fwd_bias_relu) 
add_subdirectory(07_conv2d_fwd_bias_relu_add) add_subdirectory(09_convnd_fwd) @@ -61,5 +41,9 @@ add_subdirectory(19_binary_elementwise) add_subdirectory(20_convnd_bwd_weight_xdl) add_subdirectory(21_gemm_layernorm) add_subdirectory(22_cgemm) +add_subdirectory(23_softmax) +add_subdirectory(24_batched_gemm_c_permute) +add_subdirectory(25_gemm_bias_c_permute) +add_subdirectory(26_contraction) add_subdirectory(cpu_01_conv2d_fwd) add_subdirectory(cpu_02_conv2d_fwd_bias_relu_add) diff --git a/external/include/half/half.hpp b/external/include/half/half.hpp deleted file mode 100644 index 25f543881f6728e620dd0887302672a5b64fd03d..0000000000000000000000000000000000000000 --- a/external/include/half/half.hpp +++ /dev/null @@ -1,5670 +0,0 @@ -// half - IEEE 754-based half-precision floating-point library. -// -// Copyright (c) 2012-2019 Christian Rau -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and -// associated documentation -// files (the "Software"), to deal in the Software without restriction, including without limitation -// the rights to use, copy, -// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit -// persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT -// NOT LIMITED TO THE -// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR -// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF -// CONTRACT, TORT OR OTHERWISE, -// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// Version 2.1.0 - -/// \file -/// Main header file for half-precision functionality. 
- -#ifndef HALF_HALF_HPP -#define HALF_HALF_HPP - -#define HALF_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) - -#if defined(__INTEL_COMPILER) -#define HALF_ICC_VERSION __INTEL_COMPILER -#elif defined(__ICC) -#define HALF_ICC_VERSION __ICC -#elif defined(__ICL) -#define HALF_ICC_VERSION __ICL -#else -#define HALF_ICC_VERSION 0 -#endif - -// check C++11 language features -#if defined(__clang__) // clang -#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) -#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 -#endif -#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) -#define HALF_ENABLE_CPP11_CONSTEXPR 1 -#endif -#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) -#define HALF_ENABLE_CPP11_NOEXCEPT 1 -#endif -#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) -#define HALF_ENABLE_CPP11_USER_LITERALS 1 -#endif -#if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) -#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 -#endif -#if(defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ - !defined(HALF_ENABLE_CPP11_LONG_LONG) -#define HALF_ENABLE_CPP11_LONG_LONG 1 -#endif -#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++ -#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) -#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 -#endif -#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) -#define HALF_ENABLE_CPP11_USER_LITERALS 1 -#endif -#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) -#define HALF_ENABLE_CPP11_CONSTEXPR 1 -#endif -#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) -#define HALF_ENABLE_CPP11_NOEXCEPT 1 -#endif -#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) -#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 -#endif -#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG) -#define HALF_ENABLE_CPP11_LONG_LONG 1 -#endif -#elif defined(__GNUC__) // gcc -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L -#if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) -#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 -#endif -#if HALF_GCC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) -#define HALF_ENABLE_CPP11_USER_LITERALS 1 -#endif -#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) -#define HALF_ENABLE_CPP11_CONSTEXPR 1 -#endif -#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) -#define HALF_ENABLE_CPP11_NOEXCEPT 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) -#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 -#endif -#if !defined(HALF_ENABLE_CPP11_LONG_LONG) -#define HALF_ENABLE_CPP11_LONG_LONG 1 -#endif -#endif -#define HALF_TWOS_COMPLEMENT_INT 1 -#elif defined(_MSC_VER) // Visual C++ -#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) -#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 -#endif -#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) -#define HALF_ENABLE_CPP11_USER_LITERALS 1 -#endif -#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) -#define HALF_ENABLE_CPP11_CONSTEXPR 1 -#endif -#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) -#define HALF_ENABLE_CPP11_NOEXCEPT 1 -#endif -#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) -#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 -#endif -#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) -#define HALF_ENABLE_CPP11_LONG_LONG 1 
-#endif -#define HALF_TWOS_COMPLEMENT_INT 1 -#define HALF_POP_WARNINGS 1 -#pragma warning(push) -#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if, negative unsigned -#endif - -// check C++11 library features -#include -#if defined(_LIBCPP_VERSION) // libc++ -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 -#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS -#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 -#endif -#ifndef HALF_ENABLE_CPP11_CSTDINT -#define HALF_ENABLE_CPP11_CSTDINT 1 -#endif -#ifndef HALF_ENABLE_CPP11_CMATH -#define HALF_ENABLE_CPP11_CMATH 1 -#endif -#ifndef HALF_ENABLE_CPP11_HASH -#define HALF_ENABLE_CPP11_HASH 1 -#endif -#ifndef HALF_ENABLE_CPP11_CFENV -#define HALF_ENABLE_CPP11_CFENV 1 -#endif -#endif -#elif defined(__GLIBCXX__) // libstdc++ -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 -#ifdef __clang__ -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) -#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 -#endif -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) -#define HALF_ENABLE_CPP11_CSTDINT 1 -#endif -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) -#define HALF_ENABLE_CPP11_CMATH 1 -#endif -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) -#define HALF_ENABLE_CPP11_HASH 1 -#endif -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV) -#define HALF_ENABLE_CPP11_CFENV 1 -#endif -#else -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) -#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) -#define HALF_ENABLE_CPP11_CSTDINT 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) -#define HALF_ENABLE_CPP11_CMATH 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) -#define HALF_ENABLE_CPP11_HASH 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV) -#define HALF_ENABLE_CPP11_CFENV 1 -#endif -#endif -#endif -#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ -#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) -#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 -#endif -#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT) -#define HALF_ENABLE_CPP11_CSTDINT 1 -#endif -#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH) -#define HALF_ENABLE_CPP11_HASH 1 -#endif -#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH) -#define HALF_ENABLE_CPP11_CMATH 1 -#endif -#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV) -#define HALF_ENABLE_CPP11_CFENV 1 -#endif -#endif -#undef HALF_GCC_VERSION -#undef HALF_ICC_VERSION - -// any error throwing C++ exceptions? -#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || \ - defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || \ - defined(HALF_ERRHANDLING_THROW_INEXACT) -#define HALF_ERRHANDLING_THROWS 1 -#endif - -// any error handling enabled? 
-#define HALF_ERRHANDLING \ - (HALF_ERRHANDLING_FLAGS || HALF_ERRHANDLING_ERRNO || HALF_ERRHANDLING_FENV || \ - HALF_ERRHANDLING_THROWS) - -#if HALF_ERRHANDLING -#define HALF_UNUSED_NOERR(name) name -#else -#define HALF_UNUSED_NOERR(name) -#endif - -// support constexpr -#if HALF_ENABLE_CPP11_CONSTEXPR -#define HALF_CONSTEXPR constexpr -#define HALF_CONSTEXPR_CONST constexpr -#if HALF_ERRHANDLING -#define HALF_CONSTEXPR_NOERR -#else -#define HALF_CONSTEXPR_NOERR constexpr -#endif -#else -#define HALF_CONSTEXPR -#define HALF_CONSTEXPR_CONST const -#define HALF_CONSTEXPR_NOERR -#endif - -// support noexcept -#if HALF_ENABLE_CPP11_NOEXCEPT -#define HALF_NOEXCEPT noexcept -#define HALF_NOTHROW noexcept -#else -#define HALF_NOEXCEPT -#define HALF_NOTHROW throw() -#endif - -// support thread storage -#if HALF_ENABLE_CPP11_THREAD_LOCAL -#define HALF_THREAD_LOCAL thread_local -#else -#define HALF_THREAD_LOCAL static -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#if HALF_ENABLE_CPP11_TYPE_TRAITS -#include -#endif -#if HALF_ENABLE_CPP11_CSTDINT -#include -#endif -#if HALF_ERRHANDLING_ERRNO -#include -#endif -#if HALF_ENABLE_CPP11_CFENV -#include -#endif -#if HALF_ENABLE_CPP11_HASH -#include -#endif -#if HALF_ENABLE_F16C_INTRINSICS -#include -#endif - -#ifndef HALF_ENABLE_F16C_INTRINSICS -/// Enable F16C intruction set intrinsics. -/// Defining this to 1 enables the use of [F16C compiler -/// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between -/// half-precision and single-precision values which may result in improved performance. This will -/// not perform additional checks -/// for support of the F16C instruction set, so an appropriate target platform is required when -/// enabling this feature. -/// -/// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which -/// some compilers do on supporting platforms. -#define HALF_ENABLE_F16C_INTRINSICS __F16C__ -#endif - -#ifdef HALF_DOXYGEN_ONLY -/// Type for internal floating-point computations. -/// This can be predefined to a built-in floating-point type (`float`, `double` or `long double`) to -/// override the internal -/// half-precision implementation to use this type for computing arithmetic operations and -/// mathematical function (if available). -/// This can result in improved performance for arithmetic operators and mathematical functions but -/// might cause results to -/// deviate from the specified half-precision rounding mode and inhibits proper detection of -/// half-precision exceptions. -#define HALF_ARITHMETIC_TYPE (undefined) - -/// Enable internal exception flags. -/// Defining this to 1 causes operations on half-precision values to raise internal floating-point -/// exception flags according to -/// the IEEE 754 standard. These can then be cleared and checked with clearexcept(), testexcept(). -#define HALF_ERRHANDLING_FLAGS 0 - -/// Enable exception propagation to `errno`. -/// Defining this to 1 causes operations on half-precision values to propagate floating-point -/// exceptions to -/// [errno](https://en.cppreference.com/w/cpp/error/errno) from ``. Specifically this will -/// propagate domain errors as -/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, overflow and underflow -/// errors as -/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact errors won't be -/// propagated. 
-#define HALF_ERRHANDLING_ERRNO 0 - -/// Enable exception propagation to built-in floating-point platform. -/// Defining this to 1 causes operations on half-precision values to propagate floating-point -/// exceptions to the built-in -/// single- and double-precision implementation's exception flags using the -/// [C++11 floating-point environment control](https://en.cppreference.com/w/cpp/numeric/fenv) from -/// ``. However, this -/// does not work in reverse and single- or double-precision exceptions will not raise the -/// corresponding half-precision -/// exception flags, nor will explicitly clearing flags clear the corresponding built-in flags. -#define HALF_ERRHANDLING_FENV 0 - -/// Throw C++ exception on domain errors. -/// Defining this to a string literal causes operations on half-precision values to throw a -/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified -/// message on domain errors. -#define HALF_ERRHANDLING_THROW_INVALID (undefined) - -/// Throw C++ exception on pole errors. -/// Defining this to a string literal causes operations on half-precision values to throw a -/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified -/// message on pole errors. -#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined) - -/// Throw C++ exception on overflow errors. -/// Defining this to a string literal causes operations on half-precision values to throw a -/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) with the specified -/// message on overflows. -#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined) - -/// Throw C++ exception on underflow errors. -/// Defining this to a string literal causes operations on half-precision values to throw a -/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) with the -/// specified message on underflows. -#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined) - -/// Throw C++ exception on rounding errors. -/// Defining this to 1 causes operations on half-precision values to throw a -/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with the specified -/// message on general rounding errors. -#define HALF_ERRHANDLING_THROW_INEXACT (undefined) -#endif - -#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT -/// Raise INEXACT exception on overflow. -/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in -/// addition. -/// These will be raised after any possible handling of the underflow exception. -#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 -#endif - -#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT -/// Raise INEXACT exception on underflow. -/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions -/// in addition. -/// These will be raised after any possible handling of the underflow exception. -/// -/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be -/// raised *only* when the result -/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact) -/// subnormal result. -#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1 -#endif - -/// Default rounding mode. 
-/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s -/// and more precise types -/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic -/// operations and mathematical -/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes -/// using their respective -/// constants or the equivalent values of -/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style): -/// -/// `std::float_round_style` | value | rounding -/// ---------------------------------|-------|------------------------- -/// `std::round_indeterminate` | -1 | fastest -/// `std::round_toward_zero` | 0 | toward zero -/// `std::round_to_nearest` | 1 | to nearest (default) -/// `std::round_toward_infinity` | 2 | toward positive infinity -/// `std::round_toward_neg_infinity` | 3 | toward negative infinity -/// -/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest -/// representable value. It can even -/// be set to -/// [std::numeric_limits::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style) -/// to synchronize -/// the rounding mode with that of the built-in single-precision implementation (which is likely -/// `std::round_to_nearest`, though). -#ifndef HALF_ROUND_STYLE -#define HALF_ROUND_STYLE 1 // = std::round_to_nearest -#endif - -/// Value signaling overflow. -/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value -/// signaling the overflow of an -/// operation, in particular it just evaluates to positive infinity. -/// -/// **See also:** Documentation for -/// [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL) -#define HUGE_VALH std::numeric_limits::infinity() - -/// Fast half-precision fma function. -/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a -/// separate -/// half-precision multiplication followed by an addition, which is always the case. -/// -/// **See also:** Documentation for -/// [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma) -#define FP_FAST_FMAH 1 - -/// Half rounding mode. -/// In correspondence with `FLT_ROUNDS` from `` this symbol expands to the rounding mode -/// used for -/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE). -/// -/// **See also:** Documentation for -/// [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) -#define HLF_ROUNDS HALF_ROUND_STYLE - -#ifndef FP_ILOGB0 -#define FP_ILOGB0 INT_MIN -#endif -#ifndef FP_ILOGBNAN -#define FP_ILOGBNAN INT_MAX -#endif -#ifndef FP_SUBNORMAL -#define FP_SUBNORMAL 0 -#endif -#ifndef FP_ZERO -#define FP_ZERO 1 -#endif -#ifndef FP_NAN -#define FP_NAN 2 -#endif -#ifndef FP_INFINITE -#define FP_INFINITE 3 -#endif -#ifndef FP_NORMAL -#define FP_NORMAL 4 -#endif - -#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT) -#define FE_INVALID 0x10 -#define FE_DIVBYZERO 0x08 -#define FE_OVERFLOW 0x04 -#define FE_UNDERFLOW 0x02 -#define FE_INEXACT 0x01 -#define FE_ALL_EXCEPT (FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW | FE_INEXACT) -#endif - -/// Main namespace for half-precision functionality. -/// This namespace contains all the functionality provided by the library. -namespace half_float { -class half; - -#if HALF_ENABLE_CPP11_USER_LITERALS -/// Library-defined half-precision literals. 
-/// Import this namespace to enable half-precision floating-point literals: -/// ~~~~{.cpp} -/// using namespace half_float::literal; -/// half_float::half = 4.2_h; -/// ~~~~ -namespace literal { -half operator"" _h(long double); -} -#endif - -/// \internal -/// \brief Implementation details. -namespace detail { -#if HALF_ENABLE_CPP11_TYPE_TRAITS -/// Conditional type. -template -struct conditional : std::conditional -{ -}; - -/// Helper for tag dispatching. -template -struct bool_type : std::integral_constant -{ -}; -using std::false_type; -using std::true_type; - -/// Type traits for floating-point types. -template -struct is_float : std::is_floating_point -{ -}; -#else -/// Conditional type. -template -struct conditional -{ - typedef T type; -}; -template -struct conditional -{ - typedef F type; -}; - -/// Helper for tag dispatching. -template -struct bool_type -{ -}; -typedef bool_type true_type; -typedef bool_type false_type; - -/// Type traits for floating-point types. -template -struct is_float : false_type -{ -}; -template -struct is_float : is_float -{ -}; -template -struct is_float : is_float -{ -}; -template -struct is_float : is_float -{ -}; -template <> -struct is_float : true_type -{ -}; -template <> -struct is_float : true_type -{ -}; -template <> -struct is_float : true_type -{ -}; -#endif - -/// Type traits for floating-point bits. -template -struct bits -{ - typedef unsigned char type; -}; -template -struct bits : bits -{ -}; -template -struct bits : bits -{ -}; -template -struct bits : bits -{ -}; - -#if HALF_ENABLE_CPP11_CSTDINT -/// Unsigned integer of (at least) 16 bits width. -typedef std::uint_least16_t uint16; - -/// Fastest unsigned integer of (at least) 32 bits width. -typedef std::uint_fast32_t uint32; - -/// Fastest signed integer of (at least) 32 bits width. -typedef std::int_fast32_t int32; - -/// Unsigned integer of (at least) 32 bits width. -template <> -struct bits -{ - typedef std::uint_least32_t type; -}; - -/// Unsigned integer of (at least) 64 bits width. -template <> -struct bits -{ - typedef std::uint_least64_t type; -}; -#else -/// Unsigned integer of (at least) 16 bits width. -typedef unsigned short uint16; - -/// Fastest unsigned integer of (at least) 32 bits width. -typedef unsigned long uint32; - -/// Fastest unsigned integer of (at least) 32 bits width. -typedef long int32; - -/// Unsigned integer of (at least) 32 bits width. -template <> -struct bits - : conditional::digits >= 32, unsigned int, unsigned long> -{ -}; - -#if HALF_ENABLE_CPP11_LONG_LONG -/// Unsigned integer of (at least) 64 bits width. -template <> -struct bits : conditional::digits >= 64, - unsigned long, - unsigned long long> -{ -}; -#else -/// Unsigned integer of (at least) 64 bits width. -template <> -struct bits -{ - typedef unsigned long type; -}; -#endif -#endif - -#ifdef HALF_ARITHMETIC_TYPE -/// Type to use for arithmetic computations and mathematic functions internally. -typedef HALF_ARITHMETIC_TYPE internal_t; -#endif - -/// Tag type for binary construction. -struct binary_t -{ -}; - -/// Tag for binary construction. -HALF_CONSTEXPR_CONST binary_t binary = binary_t(); - -/// \name Implementation defined classification and arithmetic -/// \{ - -/// Check for infinity. 
-/// \tparam T argument type (builtin floating-point type) -/// \param arg value to query -/// \retval true if infinity -/// \retval false else -template -bool builtin_isinf(T arg) -{ -#if HALF_ENABLE_CPP11_CMATH - return std::isinf(arg); -#elif defined(_MSC_VER) - return !::_finite(static_cast(arg)) && !::_isnan(static_cast(arg)); -#else - return arg == std::numeric_limits::infinity() || arg == -std::numeric_limits::infinity(); -#endif -} - -/// Check for NaN. -/// \tparam T argument type (builtin floating-point type) -/// \param arg value to query -/// \retval true if not a number -/// \retval false else -template -bool builtin_isnan(T arg) -{ -#if HALF_ENABLE_CPP11_CMATH - return std::isnan(arg); -#elif defined(_MSC_VER) - return ::_isnan(static_cast(arg)) != 0; -#else - return arg != arg; -#endif -} - -/// Check sign. -/// \tparam T argument type (builtin floating-point type) -/// \param arg value to query -/// \retval true if signbit set -/// \retval false else -template -bool builtin_signbit(T arg) -{ -#if HALF_ENABLE_CPP11_CMATH - return std::signbit(arg); -#else - return arg < T() || (arg == T() && T(1) / arg < T()); -#endif -} - -/// Platform-independent sign mask. -/// \param arg integer value in two's complement -/// \retval -1 if \a arg negative -/// \retval 0 if \a arg positive -inline uint32 sign_mask(uint32 arg) -{ - static const int N = std::numeric_limits::digits - 1; -#if HALF_TWOS_COMPLEMENT_INT - return static_cast(arg) >> N; -#else - return -((arg >> N) & 1); -#endif -} - -/// Platform-independent arithmetic right shift. -/// \param arg integer value in two's complement -/// \param i shift amount (at most 31) -/// \return \a arg right shifted for \a i bits with possible sign extension -inline uint32 arithmetic_shift(uint32 arg, int i) -{ -#if HALF_TWOS_COMPLEMENT_INT - return static_cast(arg) >> i; -#else - return static_cast(arg) / (static_cast(1) << i) - - ((arg >> (std::numeric_limits::digits - 1)) & 1); -#endif -} - -/// \} -/// \name Error handling -/// \{ - -/// Internal exception flags. -/// \return reference to global exception flags -inline int& errflags() -{ - HALF_THREAD_LOCAL int flags = 0; - return flags; -} - -/// Raise floating-point exception. 
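A quick aside on the two helpers above: sign_mask() and arithmetic_shift() exist because right-shifting a negative signed integer was implementation-defined before C++20, so the header emulates an arithmetic shift when it cannot rely on two's-complement shifts. A minimal standalone sketch of the same idea, with illustrative names that are not part of the header:

#include <cassert>
#include <cstdint>

// All-ones mask if the two's-complement value stored in v is "negative",
// all-zeros otherwise (no implementation-defined signed shift involved).
static std::uint32_t sign_mask_portable(std::uint32_t v)
{
    return 0u - ((v >> 31) & 1u);
}

// Arithmetic right shift emulated on the unsigned representation:
// for negative values, asr(v, i) == ~((~v) >> i).
static std::uint32_t asr_portable(std::uint32_t v, int i)
{
    std::uint32_t m = sign_mask_portable(v);
    return ((v ^ m) >> i) ^ m;
}

int main()
{
    assert(sign_mask_portable(0x80000000u) == 0xFFFFFFFFu);
    assert(asr_portable(0xFFFFFFF0u, 2) == 0xFFFFFFFCu); // -16 >> 2 == -4
    assert(asr_portable(0x00000010u, 2) == 0x00000004u);
}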
-/// \param flags exceptions to raise -/// \param cond condition to raise exceptions for -inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true) -{ -#if HALF_ERRHANDLING - if(!cond) - return; -#if HALF_ERRHANDLING_FLAGS - errflags() |= flags; -#endif -#if HALF_ERRHANDLING_ERRNO - if(flags & FE_INVALID) - errno = EDOM; - else if(flags & (FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW)) - errno = ERANGE; -#endif -#if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV - std::feraiseexcept(flags); -#endif -#ifdef HALF_ERRHANDLING_THROW_INVALID - if(flags & FE_INVALID) - throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); -#endif -#ifdef HALF_ERRHANDLING_THROW_DIVBYZERO - if(flags & FE_DIVBYZERO) - throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); -#endif -#ifdef HALF_ERRHANDLING_THROW_OVERFLOW - if(flags & FE_OVERFLOW) - throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); -#endif -#ifdef HALF_ERRHANDLING_THROW_UNDERFLOW - if(flags & FE_UNDERFLOW) - throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); -#endif -#ifdef HALF_ERRHANDLING_THROW_INEXACT - if(flags & FE_INEXACT) - throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); -#endif -#if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT - if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) - raise(FE_INEXACT); -#endif -#if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT - if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) - raise(FE_INEXACT); -#endif -#endif -} - -/// Check and signal for any NaN. -/// \param x first half-precision value to check -/// \param y second half-precision value to check -/// \retval true if either \a x or \a y is NaN -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y) -{ -#if HALF_ERRHANDLING - raise(FE_INVALID, (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00); -#endif - return (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00; -} - -/// Signal and silence signaling NaN. -/// \param nan half-precision NaN value -/// \return quiet NaN -/// \exception FE_INVALID if \a nan is signaling NaN -inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan) -{ -#if HALF_ERRHANDLING - raise(FE_INVALID, !(nan & 0x200)); -#endif - return nan | 0x200; -} - -/// Signal and silence signaling NaNs. -/// \param x first half-precision value to check -/// \param y second half-precision value to check -/// \return quiet NaN -/// \exception FE_INVALID if \a x or \a y is signaling NaN -inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y) -{ -#if HALF_ERRHANDLING - raise(FE_INVALID, - ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200))); -#endif - return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : (y | 0x200); -} - -/// Signal and silence signaling NaNs. -/// \param x first half-precision value to check -/// \param y second half-precision value to check -/// \param z third half-precision value to check -/// \return quiet NaN -/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN -inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z) -{ -#if HALF_ERRHANDLING - raise(FE_INVALID, - ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) || - ((z & 0x7FFF) > 0x7C00 && !(z & 0x200))); -#endif - return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) - : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200) : (z | 0x200); -} - -/// Select value or signaling NaN. 
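The bit tests repeated by compsignal() and the signal() overloads follow directly from the binary16 layout: an absolute value above 0x7C00 has an all-ones exponent and a non-zero mantissa (a NaN), and mantissa bit 0x200 marks the NaN as quiet. Restated as tiny standalone helpers (names are mine, not the header's):

#include <cstdint>

static bool is_nan16(std::uint16_t h)         { return (h & 0x7FFFu) > 0x7C00u; }
static bool is_signaling16(std::uint16_t h)   { return is_nan16(h) && !(h & 0x200u); }
// Quieting a signaling NaN just sets the quiet bit, which is what signal() returns.
static std::uint16_t quiet16(std::uint16_t h) { return static_cast<std::uint16_t>(h | 0x200u); }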
-/// \param x preferred half-precision value -/// \param y ignored half-precision value except for signaling NaN -/// \return \a y if signaling NaN, \a x otherwise -/// \exception FE_INVALID if \a y is signaling NaN -inline HALF_CONSTEXPR_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) -{ -#if HALF_ERRHANDLING - return (((y & 0x7FFF) > 0x7C00) && !(y & 0x200)) ? signal(y) : x; -#else - return x; -#endif -} - -/// Raise domain error and return NaN. -/// return quiet NaN -/// \exception FE_INVALID -inline HALF_CONSTEXPR_NOERR unsigned int invalid() -{ -#if HALF_ERRHANDLING - raise(FE_INVALID); -#endif - return 0x7FFF; -} - -/// Raise pole error and return infinity. -/// \param sign half-precision value with sign bit only -/// \return half-precision infinity with sign of \a sign -/// \exception FE_DIVBYZERO -inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0) -{ -#if HALF_ERRHANDLING - raise(FE_DIVBYZERO); -#endif - return sign | 0x7C00; -} - -/// Check value for underflow. -/// \param arg non-zero half-precision value to check -/// \return \a arg -/// \exception FE_UNDERFLOW if arg is subnormal -inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg) -{ -#if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT - raise(FE_UNDERFLOW, !(arg & 0x7C00)); -#endif - return arg; -} - -/// \} -/// \name Conversion and rounding -/// \{ - -/// Half-precision overflow. -/// \tparam R rounding mode to use -/// \param sign half-precision value with sign bit only -/// \return rounded overflowing half-precision value -/// \exception FE_OVERFLOW -template -HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0) -{ -#if HALF_ERRHANDLING - raise(FE_OVERFLOW); -#endif - return (R == std::round_toward_infinity) - ? (sign + 0x7C00 - (sign >> 15)) - : (R == std::round_toward_neg_infinity) - ? (sign + 0x7BFF + (sign >> 15)) - : (R == std::round_toward_zero) ? (sign | 0x7BFF) : (sign | 0x7C00); -} - -/// Half-precision underflow. -/// \tparam R rounding mode to use -/// \param sign half-precision value with sign bit only -/// \return rounded underflowing half-precision value -/// \exception FE_UNDERFLOW -template -HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0) -{ -#if HALF_ERRHANDLING - raise(FE_UNDERFLOW); -#endif - return (R == std::round_toward_infinity) - ? (sign + 1 - (sign >> 15)) - : (R == std::round_toward_neg_infinity) ? (sign + (sign >> 15)) : sign; -} - -/// Round half-precision number. -/// \tparam R rounding mode to use -/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results -/// \param value finite half-precision number to round -/// \param g guard bit (most significant discarded bit) -/// \param s sticky bit (or of all but the most significant discarded bits) -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded or \a I is `true` -template -HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s) -{ -#if HALF_ERRHANDLING - value += (R == std::round_to_nearest) - ? (g & (s | value)) - : (R == std::round_toward_infinity) - ? (~(value >> 15) & (g | s)) - : (R == std::round_toward_neg_infinity) ? 
((value >> 15) & (g | s)) : 0; - if((value & 0x7C00) == 0x7C00) - raise(FE_OVERFLOW); - else if(value & 0x7C00) - raise(FE_INEXACT, I || (g | s) != 0); - else - raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g | s) != 0); - return value; -#else - return (R == std::round_to_nearest) - ? (value + (g & (s | value))) - : (R == std::round_toward_infinity) - ? (value + (~(value >> 15) & (g | s))) - : (R == std::round_toward_neg_infinity) ? (value + ((value >> 15) & (g | s))) - : value; -#endif -} - -/// Round half-precision number to nearest integer value. -/// \tparam R rounding mode to use -/// \tparam E `true` for round to even, `false` for round away from zero -/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it -/// \param value half-precision value to round -/// \return half-precision bits for nearest integral value -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded and \a I is `true` -template -unsigned int integral(unsigned int value) -{ - unsigned int abs = value & 0x7FFF; - if(abs < 0x3C00) - { - raise(FE_INEXACT, I); - return ((R == std::round_to_nearest) - ? (0x3C00 & -static_cast(abs >= (0x3800 + E))) - : (R == std::round_toward_infinity) - ? (0x3C00 & -(~(value >> 15) & (abs != 0))) - : (R == std::round_toward_neg_infinity) - ? (0x3C00 & -static_cast(value > 0x8000)) - : 0) | - (value & 0x8000); - } - if(abs >= 0x6400) - return (abs > 0x7C00) ? signal(value) : value; - unsigned int exp = 25 - (abs >> 10), mask = (1 << exp) - 1; - raise(FE_INEXACT, I && (value & mask)); - return (((R == std::round_to_nearest) - ? ((1 << (exp - 1)) - (~(value >> exp) & E)) - : (R == std::round_toward_infinity) - ? (mask & ((value >> 15) - 1)) - : (R == std::round_toward_neg_infinity) ? (mask & -(value >> 15)) : 0) + - value) & - ~mask; -} - -/// Convert fixed point to half-precision floating-point. -/// \tparam R rounding mode to use -/// \tparam F number of fractional bits (at least 11) -/// \tparam S `true` for signed, `false` for unsigned -/// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F -/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results -/// \param m mantissa in Q1.F fixed point format -/// \param exp exponent -/// \param sign half-precision value with sign bit only -/// \param s sticky bit (or of all but the most significant already discarded bits) -/// \return value converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded or \a I is `true` -template -unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0) -{ - if(S) - { - uint32 msign = sign_mask(m); - m = (m ^ msign) - msign; - sign = msign & 0x8000; - } - if(N) - for(; m < (static_cast(1) << F) && exp; m <<= 1, --exp) - ; - else if(exp < 0) - return rounded(sign + (m >> (F - 10 - exp)), - (m >> (F - 11 - exp)) & 1, - s | ((m & ((static_cast(1) << (F - 11 - exp)) - 1)) != 0)); - return rounded(sign + (exp << 10) + (m >> (F - 10)), - (m >> (F - 11)) & 1, - s | ((m & ((static_cast(1) << (F - 11)) - 1)) != 0)); -} - -/// Convert IEEE single-precision to half-precision. -/// Credit for this goes to [Jeroen van der -/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). 
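rounded() above drives every rounding mode from a guard bit g (the most significant discarded bit) and a sticky bit s (the OR of the remaining discarded bits). For the default round-to-nearest-even mode the whole operation is a single conditional increment; a sketch under that assumption, ignoring the exception flags and the other modes:

#include <cstdint>

// 'truncated' already has its final width; g and s describe what was cut off.
// Add one ulp when the guard bit is set and either a lower discarded bit was
// set (s) or the kept value is odd (so that ties round to even).
static std::uint32_t round_nearest_even(std::uint32_t truncated, std::uint32_t g, std::uint32_t s)
{
    return truncated + (g & (s | (truncated & 1u)));
}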
-/// \tparam R rounding mode to use -/// \param value single-precision value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int float2half_impl(float value, true_type) -{ -#if HALF_ENABLE_F16C_INTRINSICS - return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value), - (R == std::round_to_nearest) - ? _MM_FROUND_TO_NEAREST_INT - : (R == std::round_toward_zero) - ? _MM_FROUND_TO_ZERO - : (R == std::round_toward_infinity) - ? _MM_FROUND_TO_POS_INF - : (R == std::round_toward_neg_infinity) - ? _MM_FROUND_TO_NEG_INF - : _MM_FROUND_CUR_DIRECTION)); -#else - bits::type fbits; - std::memcpy(&fbits, &value, sizeof(float)); -#if 1 - unsigned int sign = (fbits >> 16) & 0x8000; - fbits &= 0x7FFFFFFF; - if(fbits >= 0x7F800000) - return sign | 0x7C00 | ((fbits > 0x7F800000) ? (0x200 | ((fbits >> 13) & 0x3FF)) : 0); - if(fbits >= 0x47800000) - return overflow(sign); - if(fbits >= 0x38800000) - return rounded(sign | (((fbits >> 23) - 112) << 10) | ((fbits >> 13) & 0x3FF), - (fbits >> 12) & 1, - (fbits & 0xFFF) != 0); - if(fbits >= 0x33000000) - { - int i = 125 - (fbits >> 23); - fbits = (fbits & 0x7FFFFF) | 0x800000; - return rounded(sign | (fbits >> (i + 1)), - (fbits >> i) & 1, - (fbits & ((static_cast(1) << i) - 1)) != 0); - } - if(fbits != 0) - return underflow(sign); - return sign; -#else - static const uint16 base_table[512] = { - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, - 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, - 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, - 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 
0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, - 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, - 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, - 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, - 0xF000, 0xF400, 0xF800, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00}; - static const unsigned char shift_table[256] = { - 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, - 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, - 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, - 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, - 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, - 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, - 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13}; - int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; - fbits &= 0x7FFFFF; - uint32 m = (fbits | ((exp != 0) << 23)) & -static_cast(exp != 0xFF); - return rounded(base_table[sexp] + (fbits >> i), - (m >> (i - 1)) & 1, - (((static_cast(1) << (i - 1)) - 1) & m) != 0); -#endif -#endif -} - -/// Convert IEEE 
double-precision to half-precision. -/// \tparam R rounding mode to use -/// \param value double-precision value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int float2half_impl(double value, true_type) -{ -#if HALF_ENABLE_F16C_INTRINSICS - if(R == std::round_indeterminate) - return _mm_cvtsi128_si32( - _mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION)); -#endif - bits::type dbits; - std::memcpy(&dbits, &value, sizeof(double)); - uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; - unsigned int sign = (hi >> 16) & 0x8000; - hi &= 0x7FFFFFFF; - if(hi >= 0x7FF00000) - return sign | 0x7C00 | ((dbits & 0xFFFFFFFFFFFFF) ? (0x200 | ((hi >> 10) & 0x3FF)) : 0); - if(hi >= 0x40F00000) - return overflow(sign); - if(hi >= 0x3F100000) - return rounded(sign | (((hi >> 20) - 1008) << 10) | ((hi >> 10) & 0x3FF), - (hi >> 9) & 1, - ((hi & 0x1FF) | lo) != 0); - if(hi >= 0x3E600000) - { - int i = 1018 - (hi >> 20); - hi = (hi & 0xFFFFF) | 0x100000; - return rounded(sign | (hi >> (i + 1)), - (hi >> i) & 1, - ((hi & ((static_cast(1) << i) - 1)) | lo) != 0); - } - if((hi | lo) != 0) - return underflow(sign); - return sign; -} - -/// Convert non-IEEE floating-point to half-precision. -/// \tparam R rounding mode to use -/// \tparam T source type (builtin floating-point type) -/// \param value floating-point value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int float2half_impl(T value, ...) -{ - unsigned int hbits = static_cast(builtin_signbit(value)) << 15; - if(value == T()) - return hbits; - if(builtin_isnan(value)) - return hbits | 0x7FFF; - if(builtin_isinf(value)) - return hbits | 0x7C00; - int exp; - std::frexp(value, &exp); - if(exp > 16) - return overflow(hbits); - if(exp < -13) - value = std::ldexp(value, 25); - else - { - value = std::ldexp(value, 12 - exp); - hbits |= ((exp + 13) << 10); - } - T ival, frac = std::modf(value, &ival); - int m = std::abs(static_cast(ival)); - return rounded(hbits + (m >> 1), m & 1, frac != T()); -} - -/// Convert floating-point to half-precision. -/// \tparam R rounding mode to use -/// \tparam T source type (builtin floating-point type) -/// \param value floating-point value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int float2half(T value) -{ - return float2half_impl(value, - bool_type < std::numeric_limits::is_iec559 && - sizeof(typename bits::type) == sizeof(T) > ()); -} - -/// Convert integer to half-precision floating-point. 
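The non-table, non-F16C branch of float2half_impl() above reduces to a short standalone converter once it is specialised to round-to-nearest-even. The following sketch mirrors that branch for orientation only; it omits the overflow/underflow/inexact signalling and the other rounding modes the header supports:

#include <cstdint>
#include <cstring>

static std::uint16_t float_to_half_bits(float f)
{
    std::uint32_t x;
    std::memcpy(&x, &f, sizeof(x));
    std::uint32_t sign = (x >> 16) & 0x8000u;
    x &= 0x7FFFFFFFu;

    if(x >= 0x7F800000u) // Inf or NaN: keep a quieted NaN payload
        return static_cast<std::uint16_t>(
            sign | 0x7C00u | ((x > 0x7F800000u) ? (0x200u | ((x >> 13) & 0x3FFu)) : 0u));
    if(x >= 0x47800000u) // too large for binary16: becomes +/-infinity
        return static_cast<std::uint16_t>(sign | 0x7C00u);
    if(x >= 0x38800000u) // normal half range: rebias exponent, round mantissa
    {
        std::uint32_t v = sign | (((x >> 23) - 112u) << 10) | ((x >> 13) & 0x3FFu);
        std::uint32_t g = (x >> 12) & 1u, s = (x & 0xFFFu) != 0u;
        return static_cast<std::uint16_t>(v + (g & (s | v)));
    }
    if(x >= 0x33000000u) // subnormal half: shift mantissa by the exponent deficit
    {
        int i = 125 - static_cast<int>(x >> 23);
        x = (x & 0x7FFFFFu) | 0x800000u;
        std::uint32_t v = sign | (x >> (i + 1));
        std::uint32_t g = (x >> i) & 1u, s = (x & ((1u << i) - 1u)) != 0u;
        return static_cast<std::uint16_t>(v + (g & (s | v)));
    }
    return static_cast<std::uint16_t>(sign); // everything smaller rounds to +/-0
}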
-/// \tparam R rounding mode to use -/// \tparam T type to convert (builtin integer type) -/// \param value integral value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int int2half(T value) -{ - unsigned int bits = static_cast(value < 0) << 15; - if(!value) - return bits; - if(bits) - value = -value; - if(value > 0xFFFF) - return overflow(bits); - unsigned int m = static_cast(value), exp = 24; - for(; m < 0x400; m <<= 1, --exp) - ; - for(; m > 0x7FF; m >>= 1, ++exp) - ; - bits |= (exp << 10) + m; - return (exp > 24) ? rounded( - bits, (value >> (exp - 25)) & 1, (((1 << (exp - 25)) - 1) & value) != 0) - : bits; -} - -/// Convert half-precision to IEEE single-precision. -/// Credit for this goes to [Jeroen van der -/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). -/// \param value half-precision value to convert -/// \return single-precision value -inline float half2float_impl(unsigned int value, float, true_type) -{ -#if HALF_ENABLE_F16C_INTRINSICS - return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); -#else -#if 0 - bits::type fbits = static_cast::type>(value&0x8000) << 16; - int abs = value & 0x7FFF; - if(abs) - { - fbits |= 0x38000000 << static_cast(abs>=0x7C00); - for(; abs<0x400; abs<<=1,fbits-=0x800000) ; - fbits += static_cast::type>(abs) << 13; - } -#else - static const bits::type mantissa_table[2048] = { - 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, - 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, - 0x35600000, 0x35700000, 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, - 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, - 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, - 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, - 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, 0x36400000, - 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, - 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, - 0x367C0000, 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, - 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, - 0x369A0000, 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, - 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, - 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, 0x36C00000, 0x36C20000, - 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, - 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, - 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, - 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, - 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, - 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, - 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, 0x37100000, 0x37110000, 0x37120000, - 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, - 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, - 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, - 
0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, - 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, - 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, - 0x373D0000, 0x373E0000, 0x373F0000, 0x37400000, 0x37410000, 0x37420000, 0x37430000, - 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, - 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, - 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, - 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, - 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, - 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, - 0x376E0000, 0x376F0000, 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, - 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, - 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, 0x37810000, - 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, - 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, 0x37880000, - 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, - 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, - 0x378F8000, 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, - 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, - 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, 0x37990000, 0x37998000, - 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, - 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, - 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, - 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, - 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, - 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, - 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, - 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, - 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, - 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, - 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, 0x37C00000, - 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, - 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, - 0x37C78000, 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, - 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, - 0x37CE8000, 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, - 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, - 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, 0x37D80000, 0x37D88000, - 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, - 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, - 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, - 0x37E38000, 0x37E40000, 
0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, - 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, - 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, - 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, 0x37F00000, 0x37F08000, 0x37F10000, - 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, - 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, - 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, - 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, - 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, - 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, - 0x38034000, 0x38038000, 0x3803C000, 0x38040000, 0x38044000, 0x38048000, 0x3804C000, - 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, - 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, - 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, - 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, - 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, - 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, - 0x380F8000, 0x380FC000, 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, - 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, - 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 0x38144000, 0x38148000, - 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, - 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, 0x38180000, - 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, - 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, - 0x381BC000, 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, - 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, - 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, 0x38208000, 0x3820C000, - 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, - 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, 0x38240000, 0x38244000, - 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, - 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, - 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, - 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, - 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, - 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, - 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, - 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, - 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, 0x38340000, - 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, - 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, - 0x3837C000, 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, - 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 
0x383A8000, 0x383AC000, 0x383B0000, - 0x383B4000, 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, - 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, - 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, 0x38400000, 0x38404000, - 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, - 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, - 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, - 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, - 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, - 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, - 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, 0x384C0000, 0x384C4000, 0x384C8000, - 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, - 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, - 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, - 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, - 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, - 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, - 0x38574000, 0x38578000, 0x3857C000, 0x38580000, 0x38584000, 0x38588000, 0x3858C000, - 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, - 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, - 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, - 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, - 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, - 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, - 0x38638000, 0x3863C000, 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, - 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, - 0x38670000, 0x38674000, 0x38678000, 0x3867C000, 0x38680000, 0x38684000, 0x38688000, - 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, - 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, 0x386C0000, - 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, - 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, - 0x386FC000, 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, - 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, - 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 0x38744000, 0x38748000, 0x3874C000, - 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, - 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, 0x38780000, 0x38784000, - 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, - 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, - 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, - 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, - 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, - 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 
0x38016000, - 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, - 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, - 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, 0x38040000, - 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, - 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, - 0x3805E000, 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, - 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, - 0x3807A000, 0x3807C000, 0x3807E000, 0x38080000, 0x38082000, 0x38084000, 0x38086000, - 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, - 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, 0x380A0000, 0x380A2000, - 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, - 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, - 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, - 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, - 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, - 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, - 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, 0x38100000, 0x38102000, 0x38104000, - 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, - 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, - 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, - 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, - 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, - 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, - 0x3815A000, 0x3815C000, 0x3815E000, 0x38160000, 0x38162000, 0x38164000, 0x38166000, - 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, - 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, - 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, - 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, - 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, - 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, - 0x381BC000, 0x381BE000, 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, - 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, - 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, 0x381E4000, - 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, - 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, 0x38200000, - 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, - 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, - 0x3821E000, 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, - 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, - 0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, 0x38244000, 0x38246000, - 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, - 0x38256000, 
0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, - 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, - 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, - 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, - 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, - 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, - 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, - 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, - 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, - 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, 0x382E0000, - 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, - 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, - 0x382FE000, 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, - 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, - 0x3831A000, 0x3831C000, 0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, - 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, - 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, 0x38340000, 0x38342000, - 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, - 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, - 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, - 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, - 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, - 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, - 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, 0x383A0000, 0x383A2000, 0x383A4000, - 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, - 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, - 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, - 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, - 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, - 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, - 0x383FA000, 0x383FC000, 0x383FE000, 0x38400000, 0x38402000, 0x38404000, 0x38406000, - 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, - 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, - 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, - 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, - 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, - 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, - 0x3845C000, 0x3845E000, 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, - 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, - 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, 0x38484000, - 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, - 0x38494000, 0x38496000, 0x38498000, 
0x3849A000, 0x3849C000, 0x3849E000, 0x384A0000, - 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, - 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, - 0x384BE000, 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, - 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, - 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, - 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, - 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, - 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, - 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, - 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, - 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, - 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, - 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, - 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, - 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, - 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, 0x38580000, - 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, - 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, - 0x3859E000, 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, - 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, - 0x385BA000, 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, - 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, - 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, 0x385E0000, 0x385E2000, - 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, - 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, - 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, - 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, - 0x3861C000, 0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, - 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, - 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, 0x38640000, 0x38642000, 0x38644000, - 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, - 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, - 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, - 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, - 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, - 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, - 0x3869A000, 0x3869C000, 0x3869E000, 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, - 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, - 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, - 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, - 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 
0x386DC000, 0x386DE000, - 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, - 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, - 0x386FC000, 0x386FE000, 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, - 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, - 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, 0x38724000, - 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, - 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, 0x38740000, - 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, - 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, - 0x3875E000, 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, - 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, - 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, 0x38784000, 0x38786000, - 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, - 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, - 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, - 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, - 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, - 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, - 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, - 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, - 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000}; - static const bits::type exponent_table[64] = { - 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, - 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, - 0x07000000, 0x07800000, 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, - 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, - 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, - 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, - 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, 0x88000000, - 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, - 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, - 0xC7800000}; - static const unsigned short offset_table[64] = { - 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, - 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, - 1024, 1024, 1024, 1024, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, - 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, - 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; - bits::type fbits = - mantissa_table[offset_table[value >> 10] + (value & 0x3FF)] + exponent_table[value >> 10]; -#endif - float out; - std::memcpy(&out, &fbits, sizeof(float)); - return out; -#endif -} - -/// Convert half-precision to IEEE double-precision. 
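For the reverse direction, the large mantissa/exponent/offset tables above are a speed-for-size trade-off; the disabled '#if 0' branch earlier in half2float_impl() shows the compact shift-based form. A self-contained version of that path (it is exact, since every binary16 value is representable as a float):

#include <cstdint>
#include <cstring>

static float half_bits_to_float(std::uint16_t h)
{
    std::uint32_t f = static_cast<std::uint32_t>(h & 0x8000u) << 16;
    std::uint32_t abs = h & 0x7FFFu;
    if(abs)
    {
        // 0x38000000 is the exponent bias difference (112 << 23); Inf/NaN need
        // the float exponent field filled completely, hence the extra shift.
        f |= 0x38000000u << static_cast<unsigned>(abs >= 0x7C00u);
        for(; abs < 0x400u; abs <<= 1, f -= 0x800000u) // renormalise subnormal halves
            ;
        f += static_cast<std::uint32_t>(abs) << 13;
    }
    float out;
    std::memcpy(&out, &f, sizeof(out));
    return out;
}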
-/// \param value half-precision value to convert -/// \return double-precision value -inline double half2float_impl(unsigned int value, double, true_type) -{ -#if HALF_ENABLE_F16C_INTRINSICS - return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value)))); -#else - uint32 hi = static_cast(value & 0x8000) << 16; - unsigned int abs = value & 0x7FFF; - if(abs) - { - hi |= 0x3F000000 << static_cast(abs >= 0x7C00); - for(; abs < 0x400; abs <<= 1, hi -= 0x100000) - ; - hi += static_cast(abs) << 10; - } - bits::type dbits = static_cast::type>(hi) << 32; - double out; - std::memcpy(&out, &dbits, sizeof(double)); - return out; -#endif -} - -/// Convert half-precision to non-IEEE floating-point. -/// \tparam T type to convert to (builtin integer type) -/// \param value half-precision value to convert -/// \return floating-point value -template -T half2float_impl(unsigned int value, T, ...) -{ - T out; - unsigned int abs = value & 0x7FFF; - if(abs > 0x7C00) - out = - (std::numeric_limits::has_signaling_NaN && !(abs & 0x200)) - ? std::numeric_limits::signaling_NaN() - : std::numeric_limits::has_quiet_NaN ? std::numeric_limits::quiet_NaN() : T(); - else if(abs == 0x7C00) - out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() - : std::numeric_limits::max(); - else if(abs > 0x3FF) - out = std::ldexp(static_cast((abs & 0x3FF) | 0x400), (abs >> 10) - 25); - else - out = std::ldexp(static_cast(abs), -24); - return (value & 0x8000) ? -out : out; -} - -/// Convert half-precision to floating-point. -/// \tparam T type to convert to (builtin integer type) -/// \param value half-precision value to convert -/// \return floating-point value -template -T half2float(unsigned int value) -{ - return half2float_impl(value, - T(), - bool_type < std::numeric_limits::is_iec559 && - sizeof(typename bits::type) == sizeof(T) > ()); -} - -/// Convert half-precision floating-point to integer. -/// \tparam R rounding mode to use -/// \tparam E `true` for round to even, `false` for round away from zero -/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it -/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding -/// any implicit sign bits) -/// \param value half-precision value to convert -/// \return rounded integer value -/// \exception FE_INVALID if value is not representable in type \a T -/// \exception FE_INEXACT if value had to be rounded and \a I is `true` -template -T half2int(unsigned int value) -{ - unsigned int abs = value & 0x7FFF; - if(abs >= 0x7C00) - { - raise(FE_INVALID); - return (value & 0x8000) ? std::numeric_limits::min() : std::numeric_limits::max(); - } - if(abs < 0x3800) - { - raise(FE_INEXACT, I); - return (R == std::round_toward_infinity) - ? T(~(value >> 15) & (abs != 0)) - : (R == std::round_toward_neg_infinity) ? -T(value > 0x8000) : T(); - } - int exp = 25 - (abs >> 10); - unsigned int m = (value & 0x3FF) | 0x400; - int32 i = static_cast( - (exp <= 0) - ? (m << -exp) - : ((m + ((R == std::round_to_nearest) ? ((1 << (exp - 1)) - (~(m >> exp) & E)) - : (R == std::round_toward_infinity) - ? (((1 << exp) - 1) & ((value >> 15) - 1)) - : (R == std::round_toward_neg_infinity) - ? (((1 << exp) - 1) & -(value >> 15)) - : 0)) >> - exp)); - if((!std::numeric_limits::is_signed && (value & 0x8000)) || - (std::numeric_limits::digits < 16 && - ((value & 0x8000) ? 
(-i < std::numeric_limits::min()) - : (i > std::numeric_limits::max())))) - raise(FE_INVALID); - else if(I && exp > 0 && (m & ((1 << exp) - 1))) - raise(FE_INEXACT); - return static_cast((value & 0x8000) ? -i : i); -} - -/// \} -/// \name Mathematics -/// \{ - -/// upper part of 64-bit multiplication. -/// \tparam R rounding mode to use -/// \param x first factor -/// \param y second factor -/// \return upper 32 bit of \a x * \a y -template -uint32 mulhi(uint32 x, uint32 y) -{ - uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16), - c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16); - return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) + - ((R == std::round_to_nearest) - ? ((c >> 15) & 1) - : (R == std::round_toward_infinity) ? ((c & 0xFFFF) != 0) : 0); -} - -/// 64-bit multiplication. -/// \param x first factor -/// \param y second factor -/// \return upper 32 bit of \a x * \a y rounded to nearest -inline uint32 multiply64(uint32 x, uint32 y) -{ -#if HALF_ENABLE_CPP11_LONG_LONG - return static_cast( - (static_cast(x) * static_cast(y) + 0x80000000) >> - 32); -#else - return mulhi(x, y); -#endif -} - -/// 64-bit division. -/// \param x upper 32 bit of dividend -/// \param y divisor -/// \param s variable to store sticky bit for rounding -/// \return (\a x << 32) / \a y -inline uint32 divide64(uint32 x, uint32 y, int& s) -{ -#if HALF_ENABLE_CPP11_LONG_LONG - unsigned long long xx = static_cast(x) << 32; - return s = (xx % y != 0), static_cast(xx / y); -#else - y >>= 1; - uint32 rem = x, div = 0; - for(unsigned int i = 0; i < 32; ++i) - { - div <<= 1; - if(rem >= y) - { - rem -= y; - div |= 1; - } - rem <<= 1; - } - return s = rem > 1, div; -#endif -} - -/// Half precision positive modulus. -/// \tparam Q `true` to compute full quotient, `false` else -/// \tparam R `true` to compute signed remainder, `false` for positive remainder -/// \param x first operand as positive finite half-precision value -/// \param y second operand as positive finite half-precision value -/// \param quo adress to store quotient at, `nullptr` if \a Q `false` -/// \return modulus of \a x / \a y -template -unsigned int mod(unsigned int x, unsigned int y, int* quo = NULL) -{ - unsigned int q = 0; - if(x > y) - { - int absx = x, absy = y, expx = 0, expy = 0; - for(; absx < 0x400; absx <<= 1, --expx) - ; - for(; absy < 0x400; absy <<= 1, --expy) - ; - expx += absx >> 10; - expy += absy >> 10; - int mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; - for(int d = expx - expy; d; --d) - { - if(!Q && mx == my) - return 0; - if(mx >= my) - { - mx -= my; - q += Q; - } - mx <<= 1; - q <<= static_cast(Q); - } - if(!Q && mx == my) - return 0; - if(mx >= my) - { - mx -= my; - ++q; - } - if(Q) - { - q &= (1 << (std::numeric_limits::digits - 1)) - 1; - if(!mx) - return *quo = q, 0; - } - for(; mx < 0x400; mx <<= 1, --expy) - ; - x = (expy > 0) ? ((expy << 10) | (mx & 0x3FF)) : (mx >> (1 - expy)); - } - if(R) - { - unsigned int a, b; - if(y < 0x800) - { - a = (x < 0x400) ? (x << 1) : (x + 0x400); - b = y; - } - else - { - a = x; - b = y - 0x400; - } - if(a > b || (a == b && (q & 1))) - { - int exp = (y >> 10) + (y <= 0x3FF), d = exp - (x >> 10) - (x <= 0x3FF); - int m = (((y & 0x3FF) | ((y > 0x3FF) << 10)) << 1) - - (((x & 0x3FF) | ((x > 0x3FF) << 10)) << (1 - d)); - for(; m < 0x800 && exp > 1; m <<= 1, --exp) - ; - x = 0x8000 + ((exp - 1) << 10) + (m >> 1); - q += Q; - } - } - if(Q) - *quo = q; - return x; -} - -/// Fixed point square root. 
-/// \tparam F number of fractional bits -/// \param r radicand in Q1.F fixed point format -/// \param exp exponent -/// \return square root as Q1.F/2 -template -uint32 sqrt(uint32& r, int& exp) -{ - int i = exp & 1; - r <<= i; - exp = (exp - i) / 2; - uint32 m = 0; - for(uint32 bit = static_cast(1) << F; bit; bit >>= 2) - { - if(r < m + bit) - m >>= 1; - else - { - r -= m + bit; - m = (m >> 1) + bit; - } - } - return m; -} - -/// Fixed point binary exponential. -/// This uses the BKM algorithm in E-mode. -/// \param m exponent in [0,1) as Q0.31 -/// \param n number of iterations (at most 32) -/// \return 2 ^ \a m as Q1.31 -inline uint32 exp2(uint32 m, unsigned int n = 32) -{ - static const uint32 logs[] = { - 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, - 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, - 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, - 0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, - 0x0000000C, 0x00000006, 0x00000003, 0x00000001}; - if(!m) - return 0x80000000; - uint32 mx = 0x80000000, my = 0; - for(unsigned int i = 1; i < n; ++i) - { - uint32 mz = my + logs[i]; - if(mz <= m) - { - my = mz; - mx += mx >> i; - } - } - return mx; -} - -/// Fixed point binary logarithm. -/// This uses the BKM algorithm in L-mode. -/// \param m mantissa in [1,2) as Q1.30 -/// \param n number of iterations (at most 32) -/// \return log2(\a m) as Q0.31 -inline uint32 log2(uint32 m, unsigned int n = 32) -{ - static const uint32 logs[] = { - 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, - 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, - 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, - 0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, - 0x0000000C, 0x00000006, 0x00000003, 0x00000001}; - if(m == 0x40000000) - return 0; - uint32 mx = 0x40000000, my = 0; - for(unsigned int i = 1; i < n; ++i) - { - uint32 mz = mx + (mx >> i); - if(mz <= m) - { - mx = mz; - my += logs[i]; - } - } - return my; -} - -/// Fixed point sine and cosine. -/// This uses the CORDIC algorithm in rotation mode. -/// \param mz angle in [-pi/2,pi/2] as Q1.30 -/// \param n number of iterations (at most 31) -/// \return sine and cosine of \a mz as Q1.30 -inline std::pair sincos(uint32 mz, unsigned int n = 31) -{ - static const uint32 angles[] = { - 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, - 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, - 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, - 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, - 0x00000004, 0x00000002, 0x00000001}; - uint32 mx = 0x26DD3B6A, my = 0; - for(unsigned int i = 0; i < n; ++i) - { - uint32 sign = sign_mask(mz); - uint32 tx = mx - (arithmetic_shift(my, i) ^ sign) + sign; - uint32 ty = my + (arithmetic_shift(mx, i) ^ sign) - sign; - mx = tx; - my = ty; - mz -= (angles[i] ^ sign) - sign; - } - return std::make_pair(my, mx); -} - -/// Fixed point arc tangent. -/// This uses the CORDIC algorithm in vectoring mode. 
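The sincos() helper removed above runs CORDIC in rotation mode on Q1.30 fixed point: the angle table holds atan(2^-i) and the starting x value absorbs the CORDIC gain. A plain floating-point sketch of the same recurrence (illustrative names, not the library's fixed-point code):

```cpp
#include <cmath>
#include <cstdio>
#include <utility>

// CORDIC rotation mode: drive the residual angle z to zero by rotating through
// +/- atan(2^-i); (x, y) then converges to (cos, sin) of the input angle.
// Valid for angles roughly within [-pi/2, pi/2], as in the removed helper.
std::pair<double, double> cordic_sincos(double angle, int iterations = 30)
{
    double x = 0.6072529350088813;   // 1 / prod(sqrt(1 + 2^-2i)): pre-divided CORDIC gain
    double y = 0.0;
    double z = angle;
    for(int i = 0; i < iterations; ++i)
    {
        double d     = (z >= 0.0) ? 1.0 : -1.0;
        double shift = std::ldexp(1.0, -i);          // 2^-i
        double nx    = x - d * y * shift;
        double ny    = y + d * x * shift;
        x = nx;
        y = ny;
        z -= d * std::atan(shift);
    }
    return {y, x};   // (sin, cos), matching the order returned by the removed helper
}

int main()
{
    std::pair<double, double> sc = cordic_sincos(0.5);
    std::printf("sin %.7f (std %.7f), cos %.7f (std %.7f)\n",
                sc.first, std::sin(0.5), sc.second, std::cos(0.5));
}
```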
-/// \param my y coordinate as Q0.30 -/// \param mx x coordinate as Q0.30 -/// \param n number of iterations (at most 31) -/// \return arc tangent of \a my / \a mx as Q1.30 -inline uint32 atan2(uint32 my, uint32 mx, unsigned int n = 31) -{ - static const uint32 angles[] = { - 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, - 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, - 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, - 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, - 0x00000004, 0x00000002, 0x00000001}; - uint32 mz = 0; - for(unsigned int i = 0; i < n; ++i) - { - uint32 sign = sign_mask(my); - uint32 tx = mx + (arithmetic_shift(my, i) ^ sign) - sign; - uint32 ty = my - (arithmetic_shift(mx, i) ^ sign) + sign; - mx = tx; - my = ty; - mz += (angles[i] ^ sign) - sign; - } - return mz; -} - -/// Reduce argument for trigonometric functions. -/// \param abs half-precision floating-point value -/// \param k value to take quarter period -/// \return \a abs reduced to [-pi/4,pi/4] as Q0.30 -inline uint32 angle_arg(unsigned int abs, int& k) -{ - uint32 m = (abs & 0x3FF) | ((abs > 0x3FF) << 10); - int exp = (abs >> 10) + (abs <= 0x3FF) - 15; - if(abs < 0x3A48) - return k = 0, m << (exp + 20); -#if HALF_ENABLE_CPP11_LONG_LONG - unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL << (62 - exp)) - 1, - yi = (y + (mask >> 1)) & ~mask, f = y - yi; - uint32 sign = -static_cast(f >> 63); - k = static_cast(yi >> (62 - exp)); - return (multiply64(static_cast((sign ? -f : f) >> (31 - exp)), 0xC90FDAA2) ^ sign) - - sign; -#else - uint32 yh = m * 0xA2F98 + mulhi(m, 0x36E4E442), - yl = (m * 0x36E4E442) & 0xFFFFFFFF; - uint32 mask = (static_cast(1) << (30 - exp)) - 1, yi = (yh + (mask >> 1)) & ~mask, - sign = -static_cast(yi > yh); - k = static_cast(yi >> (30 - exp)); - uint32 fh = (yh ^ sign) + (yi ^ ~sign) - ~sign, fl = (yl ^ sign) - sign; - return (multiply64((exp > -1) - ? (((fh << (1 + exp)) & 0xFFFFFFFF) | ((fl & 0xFFFFFFFF) >> (31 - exp))) - : fh, - 0xC90FDAA2) ^ - sign) - - sign; -#endif -} - -/// Get arguments for atan2 function. -/// \param abs half-precision floating-point value -/// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 -inline std::pair atan2_args(unsigned int abs) -{ - int exp = -15; - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - uint32 my = ((abs & 0x3FF) | 0x400) << 5, r = my * my; - int rexp = 2 * exp; - r = 0x40000000 - - ((rexp > -31) ? ((r >> -rexp) | ((r & ((static_cast(1) << -rexp) - 1)) != 0)) : 1); - for(rexp = 0; r < 0x40000000; r <<= 1, --rexp) - ; - uint32 mx = sqrt<30>(r, rexp); - int d = exp - rexp; - if(d < 0) - return std::make_pair((d < -14) ? ((my >> (-d - 14)) + ((my >> (-d - 15)) & 1)) - : (my << (14 + d)), - (mx << 14) + (r << 13) / mx); - if(d > 0) - return std::make_pair(my << 14, - (d > 14) - ? ((mx >> (d - 14)) + ((mx >> (d - 15)) & 1)) - : ((d == 14) ? 
mx : ((mx << (14 - d)) + (r << (13 - d)) / mx))); - return std::make_pair(my << 13, (mx << 13) + (r << 12) / mx); -} - -/// Get exponentials for hyperbolic computation -/// \param abs half-precision floating-point value -/// \param exp variable to take unbiased exponent of larger result -/// \param n number of BKM iterations (at most 32) -/// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent -inline std::pair hyperbolic_args(unsigned int abs, int& exp, unsigned int n = 32) -{ - uint32 mx = detail::multiply64(static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, - 0xB8AA3B29), - my; - int e = (abs >> 10) + (abs <= 0x3FF); - if(e < 14) - { - exp = 0; - mx >>= 14 - e; - } - else - { - exp = mx >> (45 - e); - mx = (mx << (e - 14)) & 0x7FFFFFFF; - } - mx = exp2(mx, n); - int d = exp << 1, s; - if(mx > 0x80000000) - { - my = divide64(0x80000000, mx, s); - my |= s; - ++d; - } - else - my = mx; - return std::make_pair( - mx, (d < 31) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1); -} - -/// Postprocessing for binary exponential. -/// \tparam R rounding mode to use -/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results -/// \param m mantissa as Q1.31 -/// \param exp absolute value of unbiased exponent -/// \param esign sign of actual exponent -/// \param sign sign bit of result -/// \return value converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded or \a I is `true` -template -unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0) -{ - int s = 0; - if(esign) - { - if(m > 0x80000000) - { - m = divide64(0x80000000, m, s); - ++exp; - } - if(exp > 25) - return underflow(sign); - else if(exp == 25) - return rounded(sign, 1, (m & 0x7FFFFFFF) != 0); - exp = -exp; - } - else if(exp > 15) - return overflow(sign); - return fixed2half(m, exp + 14, sign, s); -} - -/// Postprocessing for binary logarithm. -/// \tparam R rounding mode to use -/// \tparam L logarithm for base transformation as Q1.31 -/// \param m fractional part of logarithm as Q0.31 -/// \param ilog signed integer part of logarithm -/// \param exp biased exponent of result -/// \param sign sign bit of result -/// \return value base-transformed and converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if no other exception occurred -template -unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) -{ - uint32 msign = sign_mask(ilog); - m = (((static_cast(ilog) << 27) + (m >> 4)) ^ msign) - msign; - if(!m) - return 0; - for(; m < 0x80000000; m <<= 1, --exp) - ; - int i = m >= L, s; - exp += i; - m >>= 1 + i; - sign ^= msign & 0x8000; - if(exp < -11) - return underflow(sign); - m = divide64(m, L, s); - return fixed2half(m, exp, sign, 1); -} - -/// Hypotenuse square root and postprocessing. 
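The exp2() and log2() helpers removed above implement the BKM algorithm in Q-format fixed point: the logs[] table holds log2(1 + 2^-i), and 2^x is assembled as a product of the factors (1 + 2^-i) whose logarithms fit under x. A floating-point sketch of the E-mode recurrence (illustrative names):

```cpp
#include <cmath>
#include <cstdio>

// BKM in E-mode: greedily accept factors (1 + 2^-i) while the running sum of
// their base-2 logarithms stays below x; the product then approximates 2^x.
// Intended for x in [0, 1), like the removed fixed-point helper.
double bkm_exp2(double x, int iterations = 40)
{
    double result = 1.0;
    double logsum = 0.0;
    for(int i = 1; i <= iterations; ++i)
    {
        double factor = 1.0 + std::ldexp(1.0, -i);   // 1 + 2^-i
        double step   = std::log2(factor);           // the logs[] table entry
        if(logsum + step <= x)
        {
            logsum += step;
            result *= factor;
        }
    }
    return result;
}

int main()
{
    std::printf("%.9f vs %.9f\n", bkm_exp2(0.75), std::exp2(0.75));
}
```

L-mode (the removed log2()) runs the same loop the other way around: it accepts a factor whenever multiplying it into the running product keeps the product at or below the input mantissa, and returns the accumulated logarithms.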
-/// \tparam R rounding mode to use -/// \param r mantissa as Q2.30 -/// \param exp unbiased exponent -/// \return square root converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int hypot_post(uint32 r, int exp) -{ - int i = r >> 31; - if((exp += i) > 46) - return overflow(); - if(exp < -34) - return underflow(); - r = (r >> i) | (r & i); - uint32 m = sqrt<30>(r, exp += 15); - return fixed2half(m, exp - 1, 0, r != 0); -} - -/// Division and postprocessing for tangents. -/// \tparam R rounding mode to use -/// \param my dividend as Q1.31 -/// \param mx divisor as Q1.31 -/// \param exp biased exponent of result -/// \param sign sign bit of result -/// \return quotient converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if no other exception occurred -template -unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0) -{ - int i = my >= mx, s; - exp += i; - if(exp > 29) - return overflow(sign); - if(exp < -11) - return underflow(sign); - uint32 m = divide64(my >> (i + 1), mx, s); - return fixed2half(m, exp, sign, s); -} - -/// Area function and postprocessing. -/// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) = -/// log(x+sqrt(x^2+|-1))`. -/// \tparam R rounding mode to use -/// \tparam S `true` for asinh, `false` for acosh -/// \param arg half-precision argument -/// \return asinh|acosh(\a arg) converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if no other exception occurred -template -unsigned int area(unsigned int arg) -{ - int abs = arg & 0x7FFF, expx = (abs >> 10) + (abs <= 0x3FF) - 15, expy = -15, ilog, i; - uint32 mx = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) << 20, my, r; - for(; abs < 0x400; abs <<= 1, --expy) - ; - expy += abs >> 10; - r = ((abs & 0x3FF) | 0x400) << 5; - r *= r; - i = r >> 31; - expy = 2 * expy + i; - r >>= i; - if(S) - { - if(expy < 0) - { - r = 0x40000000 + ((expy > -30) ? ((r >> -expy) | - ((r & ((static_cast(1) << -expy) - 1)) != 0)) - : 1); - expy = 0; - } - else - { - r += 0x40000000 >> expy; - i = r >> 31; - r = (r >> i) | (r & i); - expy += i; - } - } - else - { - r -= 0x40000000 >> expy; - for(; r < 0x40000000; r <<= 1, --expy) - ; - } - my = sqrt<30>(r, expy); - my = (my << 15) + (r << 14) / my; - if(S) - { - mx >>= expy - expx; - ilog = expy; - } - else - { - my >>= expx - expy; - ilog = expx; - } - my += mx; - i = my >> 31; - static const int G = S && (R == std::round_to_nearest); - return log2_post( - log2(my >> i, 26 + S + G) + (G << 3), ilog + i, 17, arg & (static_cast(S) << 15)); -} - -/// Class for 1.31 unsigned floating-point computation -struct f31 -{ - /// Constructor. - /// \param mant mantissa as 1.31 - /// \param e exponent - HALF_CONSTEXPR f31(uint32 mant, int e) : m(mant), exp(e) {} - - /// Constructor. - /// \param abs unsigned half-precision value - f31(unsigned int abs) : exp(-15) - { - for(; abs < 0x400; abs <<= 1, --exp) - ; - m = static_cast((abs & 0x3FF) | 0x400) << 21; - exp += (abs >> 10); - } - - /// Addition operator. - /// \param a first operand - /// \param b second operand - /// \return \a a + \a b - friend f31 operator+(f31 a, f31 b) - { - if(b.exp > a.exp) - std::swap(a, b); - int d = a.exp - b.exp; - uint32 m = a.m + ((d < 32) ? 
(b.m >> d) : 0); - int i = (m & 0xFFFFFFFF) < a.m; - return f31(((m + i) >> i) | 0x80000000, a.exp + i); - } - - /// Subtraction operator. - /// \param a first operand - /// \param b second operand - /// \return \a a - \a b - friend f31 operator-(f31 a, f31 b) - { - int d = a.exp - b.exp, exp = a.exp; - uint32 m = a.m - ((d < 32) ? (b.m >> d) : 0); - if(!m) - return f31(0, -32); - for(; m < 0x80000000; m <<= 1, --exp) - ; - return f31(m, exp); - } - - /// Multiplication operator. - /// \param a first operand - /// \param b second operand - /// \return \a a * \a b - friend f31 operator*(f31 a, f31 b) - { - uint32 m = multiply64(a.m, b.m); - int i = m >> 31; - return f31(m << (1 - i), a.exp + b.exp + i); - } - - /// Division operator. - /// \param a first operand - /// \param b second operand - /// \return \a a / \a b - friend f31 operator/(f31 a, f31 b) - { - int i = a.m >= b.m, s; - uint32 m = divide64((a.m + i) >> i, b.m, s); - return f31(m, a.exp - b.exp + i - 1); - } - - uint32 m; ///< mantissa as 1.31. - int exp; ///< exponent. -}; - -/// Error function and postprocessing. -/// This computes the value directly in Q1.31 using the approximations given -/// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). -/// \tparam R rounding mode to use -/// \tparam C `true` for comlementary error function, `false` else -/// \param arg half-precision function argument -/// \return approximated value of error function in half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if no other exception occurred -template -unsigned int erf(unsigned int arg) -{ - unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; - f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0), - t = f31(0x80000000, 0) / (f31(0x80000000, 0) + f31(0xA7BA054A, -2) * x), t2 = t * t; - f31 e = ((f31(0x87DC2213, 0) * t2 + f31(0xB5F0E2AE, 0)) * t2 + f31(0x82790637, -2) - - (f31(0xBA00E2B8, 0) * t2 + f31(0x91A98E62, -2)) * t) * - t / - ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0) - : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), x2.m >> (31 - x2.exp))); - return (!C || sign) - ? fixed2half( - 0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U)) - : (e.exp < -25) - ? underflow() - : fixed2half(e.m >> 1, e.exp + 14, 0, e.m & 1); -} - -/// Gamma function and postprocessing. -/// This approximates the value of either the gamma function or its logarithm directly in Q1.31. -/// \tparam R rounding mode to use -/// \tparam L `true` for lograithm of gamma function, `false` for gamma function -/// \param arg half-precision floating-point value -/// \return lgamma/tgamma(\a arg) in half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if \a arg is not a positive integer -template -unsigned int gamma(unsigned int arg) -{ - /* static const double p[] ={ 2.50662827563479526904, 225.525584619175212544, - -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837, - 0.0114684895434781459556 }; double t = arg + 4.65, s = p[0]; for(unsigned int i=0; i<5; ++i) - s += p[i+1] / (arg+i); - return std::log(s) + (arg-0.5)*std::log(t) - t; -*/ static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0); - unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; - bool bsign = sign != 0; - f31 z(abs), x = sign ? 
(z + f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2), - s = f31(0xA06C9901, 1) + f31(0xBBE654E2, -7) / (x + f31(0x80000000, 2)) + - f31(0xA1CE6098, 6) / (x + f31(0x80000000, 1)) + f31(0xE1868CB7, 7) / x - - f31(0x8625E279, 8) / (x + f31(0x80000000, 0)) - - f31(0xA03E158F, 2) / (x + f31(0xC0000000, 1)); - int i = (s.exp >= 2) + (s.exp >= 4) + (s.exp >= 8) + (s.exp >= 16); - s = f31((static_cast(s.exp) << (31 - i)) + (log2(s.m >> 1, 28) >> i), i) / lbe; - if(x.exp != -1 || x.m != 0x80000000) - { - i = (t.exp >= 2) + (t.exp >= 4) + (t.exp >= 8); - f31 l = f31((static_cast(t.exp) << (31 - i)) + (log2(t.m >> 1, 30) >> i), i) / lbe; - s = (x.exp < -1) ? (s - (f31(0x80000000, -1) - x) * l) - : (s + (x - f31(0x80000000, -1)) * l); - } - s = x.exp ? (s - t) : (t - s); - if(bsign) - { - if(z.exp >= 0) - { - sign &= (L | ((z.m >> (31 - z.exp)) & 1)) - 1; - for(z = f31((z.m << (1 + z.exp)) & 0xFFFFFFFF, -1); z.m < 0x80000000; - z.m <<= 1, --z.exp) - ; - } - if(z.exp == -1) - z = f31(0x80000000, 0) - z; - if(z.exp < -1) - { - z = z * pi; - z.m = sincos(z.m >> (1 - z.exp), 30).first; - for(z.exp = 1; z.m < 0x80000000; z.m <<= 1, --z.exp) - ; - } - else - z = f31(0x80000000, 0); - } - if(L) - { - if(bsign) - { - f31 l(0x92868247, 0); - if(z.exp < 0) - { - uint32 m = log2((z.m + 1) >> 1, 27); - z = f31(-((static_cast(z.exp) << 26) + (m >> 5)), 5); - for(; z.m < 0x80000000; z.m <<= 1, --z.exp) - ; - l = l + z / lbe; - } - sign = static_cast(x.exp && (l.exp < s.exp || (l.exp == s.exp && l.m < s.m))) - << 15; - s = sign ? (s - l) : x.exp ? (l - s) : (l + s); - } - else - { - sign = static_cast(x.exp == 0) << 15; - if(s.exp < -24) - return underflow(sign); - if(s.exp > 15) - return overflow(sign); - } - } - else - { - s = s * lbe; - uint32 m; - if(s.exp < 0) - { - m = s.m >> -s.exp; - s.exp = 0; - } - else - { - m = (s.m << s.exp) & 0x7FFFFFFF; - s.exp = (s.m >> (31 - s.exp)); - } - s.m = exp2(m, 27); - if(!x.exp) - s = f31(0x80000000, 0) / s; - if(bsign) - { - if(z.exp < 0) - s = s * z; - s = pi / s; - if(s.exp < -24) - return underflow(sign); - } - else if(z.exp > 0 && !(z.m & ((1 << (31 - z.exp)) - 1))) - return ((s.exp + 14) << 10) + (s.m >> 21); - if(s.exp > 15) - return overflow(sign); - } - return fixed2half(s.m, s.exp + 14, sign); -} -/// \} - -template -struct half_caster; -} // namespace detail - -/// Half-precision floating-point type. -/// This class implements an IEEE-conformant half-precision floating-point type with the usual -/// arithmetic -/// operators and conversions. It is implicitly convertible to single-precision floating-point, -/// which makes artihmetic -/// expressions and functions with mixed-type operands to be of the most precise operand type. -/// -/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's -/// less strict and -/// extended definitions it is both a standard layout type and a trivially copyable type (even if -/// not a POD type), which -/// means it can be standard-conformantly copied using raw binary copies. But in this context some -/// more words about the -/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not -/// neccessarily have to be of -/// exactly 16-bits size. But on any reasonable implementation the actual binary representation of -/// this type will most -/// probably not ivolve any additional "magic" or padding beyond the simple binary representation of -/// the underlying 16-bit -/// IEEE number, even if not strictly guaranteed by the standard. 
But even then it only has an -/// actual size of 16 bits if -/// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this -/// should be the case on -/// nearly any reasonable platform. -/// -/// So if your C++ implementation is not totally exotic or imposes special alignment requirements, -/// it is a reasonable -/// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE -/// representation. -class half -{ - public: - /// \name Construction and assignment - /// \{ - - /// Default constructor. - /// This initializes the half to 0. Although this does not match the builtin types' - /// default-initialization semantics - /// and may be less efficient than no initialization, it is needed to provide proper - /// value-initialization semantics. - HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} - - /// Conversion constructor. - /// \param rhs float to convert - /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding - explicit half(float rhs) - : data_(static_cast(detail::float2half(rhs))) - { - } - - /// Conversion to single-precision. - /// \return single precision value representing expression value - operator float() const { return detail::half2float(data_); } - - /// Assignment operator. - /// \param rhs single-precision value to copy from - /// \return reference to this half - /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding - half& operator=(float rhs) - { - data_ = static_cast(detail::float2half(rhs)); - return *this; - } - - /// \} - /// \name Arithmetic updates - /// \{ - - /// Arithmetic assignment. - /// \tparam T type of concrete half expression - /// \param rhs half expression to add - /// \return reference to this half - /// \exception FE_... according to operator+(half,half) - half& operator+=(half rhs) { return *this = *this + rhs; } - - /// Arithmetic assignment. - /// \tparam T type of concrete half expression - /// \param rhs half expression to subtract - /// \return reference to this half - /// \exception FE_... according to operator-(half,half) - half& operator-=(half rhs) { return *this = *this - rhs; } - - /// Arithmetic assignment. - /// \tparam T type of concrete half expression - /// \param rhs half expression to multiply with - /// \return reference to this half - /// \exception FE_... according to operator*(half,half) - half& operator*=(half rhs) { return *this = *this * rhs; } - - /// Arithmetic assignment. - /// \tparam T type of concrete half expression - /// \param rhs half expression to divide by - /// \return reference to this half - /// \exception FE_... according to operator/(half,half) - half& operator/=(half rhs) { return *this = *this / rhs; } - - /// Arithmetic assignment. - /// \param rhs single-precision value to add - /// \return reference to this half - /// \exception FE_... according to operator=() - half& operator+=(float rhs) { return *this = *this + rhs; } - - /// Arithmetic assignment. - /// \param rhs single-precision value to subtract - /// \return reference to this half - /// \exception FE_... according to operator=() - half& operator-=(float rhs) { return *this = *this - rhs; } - - /// Arithmetic assignment. - /// \param rhs single-precision value to multiply with - /// \return reference to this half - /// \exception FE_... according to operator=() - half& operator*=(float rhs) { return *this = *this * rhs; } - - /// Arithmetic assignment. 
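The removed class documentation above states that half is a standard-layout, trivially copyable wrapper around a single 16-bit integer, and that it occupies exactly two bytes wherever the platform provides a 16-bit unsigned integer type. Under those assumptions the claims can be checked at compile time (illustrative snippet, not part of the header):

```cpp
#include <type_traits>
#include "half.hpp"   // the header removed by this diff; include path is illustrative

static_assert(std::is_standard_layout<half_float::half>::value,
              "half is documented as a standard layout type");
static_assert(std::is_trivially_copyable<half_float::half>::value,
              "half is documented as trivially copyable");
// Holds whenever detail::uint16 is an exactly-16-bit unsigned type, which the
// documentation above describes as the case on nearly any reasonable platform.
static_assert(sizeof(half_float::half) == 2,
              "half occupies the two bytes of the underlying IEEE representation");

int main() { return 0; }
```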
- /// \param rhs single-precision value to divide by - /// \return reference to this half - /// \exception FE_... according to operator=() - half& operator/=(float rhs) { return *this = *this / rhs; } - - /// \} - /// \name Increment and decrement - /// \{ - - /// Prefix increment. - /// \return incremented half value - /// \exception FE_... according to operator+(half,half) - half& operator++() { return *this = *this + half(detail::binary, 0x3C00); } - - /// Prefix decrement. - /// \return decremented half value - /// \exception FE_... according to operator-(half,half) - half& operator--() { return *this = *this + half(detail::binary, 0xBC00); } - - /// Postfix increment. - /// \return non-incremented half value - /// \exception FE_... according to operator+(half,half) - half operator++(int) - { - half out(*this); - ++*this; - return out; - } - - /// Postfix decrement. - /// \return non-decremented half value - /// \exception FE_... according to operator-(half,half) - half operator--(int) - { - half out(*this); - --*this; - return out; - } - /// \} - - private: - /// Rounding mode to use - static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); - - /// Constructor. - /// \param bits binary representation to set half to - HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT - : data_(static_cast(bits)) - { - } - - /// Internal binary representation - detail::uint16 data_; - -#ifndef HALF_DOXYGEN_ONLY - friend HALF_CONSTEXPR_NOERR bool operator==(half, half); - friend HALF_CONSTEXPR_NOERR bool operator!=(half, half); - friend HALF_CONSTEXPR_NOERR bool operator<(half, half); - friend HALF_CONSTEXPR_NOERR bool operator>(half, half); - friend HALF_CONSTEXPR_NOERR bool operator<=(half, half); - friend HALF_CONSTEXPR_NOERR bool operator>=(half, half); - friend HALF_CONSTEXPR half operator-(half); - friend half operator+(half, half); - friend half operator-(half, half); - friend half operator*(half, half); - friend half operator/(half, half); - template - friend std::basic_ostream& operator<<(std::basic_ostream&, half); - template - friend std::basic_istream& operator>>(std::basic_istream&, half&); - friend HALF_CONSTEXPR half fabs(half); - friend half fmod(half, half); - friend half remainder(half, half); - friend half remquo(half, half, int*); - friend half fma(half, half, half); - friend HALF_CONSTEXPR_NOERR half fmax(half, half); - friend HALF_CONSTEXPR_NOERR half fmin(half, half); - friend half fdim(half, half); - friend half nanh(const char*); - friend half exp(half); - friend half exp2(half); - friend half expm1(half); - friend half log(half); - friend half log10(half); - friend half log2(half); - friend half log1p(half); - friend half sqrt(half); - friend half cbrt(half); - friend half hypot(half, half); - friend half hypot(half, half, half); - friend half pow(half, half); - friend void sincos(half, half*, half*); - friend half sin(half); - friend half cos(half); - friend half tan(half); - friend half asin(half); - friend half acos(half); - friend half atan(half); - friend half atan2(half, half); - friend half sinh(half); - friend half cosh(half); - friend half tanh(half); - friend half asinh(half); - friend half acosh(half); - friend half atanh(half); - friend half erf(half); - friend half erfc(half); - friend half lgamma(half); - friend half tgamma(half); - friend half ceil(half); - friend half floor(half); - friend half trunc(half); - friend half round(half); - friend long lround(half); - friend half rint(half); - friend long 
lrint(half); - friend half nearbyint(half); -#ifdef HALF_ENABLE_CPP11_LONG_LONG - friend long long llround(half); - friend long long llrint(half); -#endif - friend half frexp(half, int*); - friend half scalbln(half, long); - friend half modf(half, half*); - friend int ilogb(half); - friend half logb(half); - friend half nextafter(half, half); - friend half nexttoward(half, long double); - friend HALF_CONSTEXPR half copysign(half, half); - friend HALF_CONSTEXPR int fpclassify(half); - friend HALF_CONSTEXPR bool isfinite(half); - friend HALF_CONSTEXPR bool isinf(half); - friend HALF_CONSTEXPR bool isnan(half); - friend HALF_CONSTEXPR bool isnormal(half); - friend HALF_CONSTEXPR bool signbit(half); - friend HALF_CONSTEXPR bool isgreater(half, half); - friend HALF_CONSTEXPR bool isgreaterequal(half, half); - friend HALF_CONSTEXPR bool isless(half, half); - friend HALF_CONSTEXPR bool islessequal(half, half); - friend HALF_CONSTEXPR bool islessgreater(half, half); - template - friend struct detail::half_caster; - friend class std::numeric_limits; -#if HALF_ENABLE_CPP11_HASH - friend struct std::hash; -#endif -#if HALF_ENABLE_CPP11_USER_LITERALS - friend half literal::operator"" _h(long double); -#endif -#endif -}; - -#if HALF_ENABLE_CPP11_USER_LITERALS -namespace literal { -/// Half literal. -/// While this returns a properly rounded half-precision value, half literals can unfortunately not -/// be constant -/// expressions due to rather involved conversions. So don't expect this to be a literal literal -/// without involving -/// conversion operations at runtime. It is a convenience feature, not a performance optimization. -/// \param value literal value -/// \return half with of given value (possibly rounded) -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator"" _h(long double value) -{ - return half(detail::binary, detail::float2half(value)); -} -} // namespace literal -#endif - -namespace detail { -/// Helper class for half casts. -/// This class template has to be specialized for all valid cast arguments to define an appropriate -/// static -/// `cast` member function and a corresponding `type` member denoting its return type. -/// \tparam T destination type -/// \tparam U source type -/// \tparam R rounding mode to use -template -struct half_caster -{ -}; -template -struct half_caster -{ -#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS - static_assert(std::is_arithmetic::value, "half_cast from non-arithmetic type unsupported"); -#endif - - static half cast(U arg) { return cast_impl(arg, is_float()); }; - - private: - static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); } - static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); } -}; -template -struct half_caster -{ -#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS - static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); -#endif - - static T cast(half arg) { return cast_impl(arg, is_float()); } - - private: - static T cast_impl(half arg, true_type) { return half2float(arg.data_); } - static T cast_impl(half arg, false_type) { return half2int(arg.data_); } -}; -template -struct half_caster -{ - static half cast(half arg) { return arg; } -}; -} // namespace detail -} // namespace half_float - -/// Extensions to the C++ standard library. -namespace std { -/// Numeric limits for half-precision floats. 
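Taken together, the members removed above (the explicit constructor from float, the implicit conversion back, the compound assignments and increment/decrement, the _h literal, and the stream operators that appear further down in this removed file) give the type its drop-in usage pattern. A minimal usage sketch, assuming the header is still available at an illustrative include path:

```cpp
#include <iostream>
#include "half.hpp"   // the header removed by this diff; include path is illustrative

using half_float::half;
using namespace half_float::literal;   // needs HALF_ENABLE_CPP11_USER_LITERALS

int main()
{
    half a(3.5f);             // explicit construction rounds the float to binary16
    half b = 0.25_h;          // user-defined literal, rounded at runtime
    half c = a * b + 1.0_h;   // arithmetic is carried out in half precision
    c += 2.0f;                // mixed assignment goes through float and rounds back

    float f = c;              // implicit widening back to single precision
    std::cout << c << " " << f << "\n";
}
```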
-/// **See also:** Documentation for -/// [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits) -template <> -class numeric_limits -{ - public: - /// Is template specialization. - static HALF_CONSTEXPR_CONST bool is_specialized = true; - - /// Supports signed values. - static HALF_CONSTEXPR_CONST bool is_signed = true; - - /// Is not an integer type. - static HALF_CONSTEXPR_CONST bool is_integer = false; - - /// Is not exact. - static HALF_CONSTEXPR_CONST bool is_exact = false; - - /// Doesn't provide modulo arithmetic. - static HALF_CONSTEXPR_CONST bool is_modulo = false; - - /// Has a finite set of values. - static HALF_CONSTEXPR_CONST bool is_bounded = true; - - /// IEEE conformant. - static HALF_CONSTEXPR_CONST bool is_iec559 = true; - - /// Supports infinity. - static HALF_CONSTEXPR_CONST bool has_infinity = true; - - /// Supports quiet NaNs. - static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; - - /// Supports signaling NaNs. - static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true; - - /// Supports subnormal values. - static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; - - /// Supports no denormalization detection. - static HALF_CONSTEXPR_CONST bool has_denorm_loss = false; - -#if HALF_ERRHANDLING_THROWS - static HALF_CONSTEXPR_CONST bool traps = true; -#else - /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is - /// acitvated. - static HALF_CONSTEXPR_CONST bool traps = false; -#endif - - /// Does not support no pre-rounding underflow detection. - static HALF_CONSTEXPR_CONST bool tinyness_before = false; - - /// Rounding mode. - static HALF_CONSTEXPR_CONST float_round_style round_style = half_float::half::round_style; - - /// Significant digits. - static HALF_CONSTEXPR_CONST int digits = 11; - - /// Significant decimal digits. - static HALF_CONSTEXPR_CONST int digits10 = 3; - - /// Required decimal digits to represent all possible values. - static HALF_CONSTEXPR_CONST int max_digits10 = 5; - - /// Number base. - static HALF_CONSTEXPR_CONST int radix = 2; - - /// One more than smallest exponent. - static HALF_CONSTEXPR_CONST int min_exponent = -13; - - /// Smallest normalized representable power of 10. - static HALF_CONSTEXPR_CONST int min_exponent10 = -4; - - /// One more than largest exponent - static HALF_CONSTEXPR_CONST int max_exponent = 16; - - /// Largest finitely representable power of 10. - static HALF_CONSTEXPR_CONST int max_exponent10 = 4; - - /// Smallest positive normal value. - static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x0400); - } - - /// Smallest finite value. - static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0xFBFF); - } - - /// Largest finite value. - static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x7BFF); - } - - /// Difference between 1 and next representable value. - static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x1400); - } - - /// Maximum rounding error in ULP (units in the last place). - static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, - (round_style == std::round_to_nearest) ? 0x3800 : 0x3C00); - } - - /// Positive infinity. 
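The numeric_limits specialization removed above encodes the standard IEEE binary16 constants: 11 significand bits, a smallest normal value of 2^-14, a largest finite value of 65504, and the bit patterns 0x0400, 0x7BFF, 0x1400 and 0x0001 for min, max, epsilon and denorm_min. A short check of the expected values (illustrative include path):

```cpp
#include <cstdio>
#include <limits>
#include "half.hpp"   // the header removed by this diff; include path is illustrative

int main()
{
    using half_float::half;
    using lim = std::numeric_limits<half>;
    // Expected for IEEE binary16:
    //   min()        = 2^-14 ~ 6.10352e-05  (bits 0x0400)
    //   max()        = 65504                (bits 0x7BFF)
    //   epsilon()    = 2^-10 ~ 9.76562e-04  (bits 0x1400)
    //   denorm_min() = 2^-24 ~ 5.96046e-08  (bits 0x0001)
    std::printf("min=%g max=%g eps=%g denorm_min=%g digits=%d\n",
                float(lim::min()), float(lim::max()),
                float(lim::epsilon()), float(lim::denorm_min()), lim::digits);
}
```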
- static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x7C00); - } - - /// Quiet NaN. - static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x7FFF); - } - - /// Signaling NaN. - static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x7DFF); - } - - /// Smallest positive subnormal value. - static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x0001); - } -}; - -#if HALF_ENABLE_CPP11_HASH -/// Hash function for half-precision floats. -/// This is only defined if C++11 `std::hash` is supported and enabled. -/// -/// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash) -template <> -struct hash -{ - /// Type of function argument. - typedef half_float::half argument_type; - - /// Function return type. - typedef size_t result_type; - - /// Compute hash function. - /// \param arg half to hash - /// \return hash value - result_type operator()(argument_type arg) const - { - return hash()(arg.data_ & - -static_cast(arg.data_ != 0x8000)); - } -}; -#endif -} // namespace std - -namespace half_float { -/// \anchor compop -/// \name Comparison operators -/// \{ - -/// Comparison for equality. -/// \param x first operand -/// \param y second operand -/// \retval true if operands equal -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - (x.data_ == y.data_ || !((x.data_ | y.data_) & 0x7FFF)); -} - -/// Comparison for inequality. -/// \param x first operand -/// \param y second operand -/// \retval true if operands not equal -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y) -{ - return detail::compsignal(x.data_, y.data_) || - (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF)); -} - -/// Comparison for less than. -/// \param x first operand -/// \param y second operand -/// \retval true if \a x less than \a y -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); -} - -/// Comparison for greater than. -/// \param x first operand -/// \param y second operand -/// \retval true if \a x greater than \a y -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); -} - -/// Comparison for less equal. 
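The comparison operators removed above avoid converting to float by mapping the sign-magnitude half bit pattern to an unsigned key whose integer order matches the numeric order (NaNs having been screened out by compsignal beforehand). The mapping can be isolated as follows; ordered_key is an illustrative name for the inlined expression:

```cpp
#include <cassert>
#include <cstdint>

// Order-preserving key for non-NaN binary16 bit patterns: positive values get
// the sign bit set (shifting them above all negatives), negative values are
// two's-complement negated so larger magnitudes sort lower; +0 and -0 coincide.
inline std::uint32_t ordered_key(std::uint16_t bits)
{
    std::uint32_t sign = bits >> 15;
    return (bits ^ (0x8000u | (0x8000u - sign))) + sign;
}

int main()
{
    // 0xBC00 = -1.0, 0x8000 = -0.0, 0x0000 = +0.0, 0x3C00 = 1.0, 0x4000 = 2.0
    assert(ordered_key(0xBC00) < ordered_key(0x8000));
    assert(ordered_key(0x8000) == ordered_key(0x0000));
    assert(ordered_key(0x3C00) < ordered_key(0x4000));
    return 0;
}
```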
-/// \param x first operand -/// \param y second operand -/// \retval true if \a x less equal \a y -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); -} - -/// Comparison for greater equal. -/// \param x first operand -/// \param y second operand -/// \retval true if \a x greater equal \a y -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); -} - -/// \} -/// \anchor arithmetics -/// \name Arithmetic operators -/// \{ - -/// Identity. -/// \param arg operand -/// \return unchanged operand -inline HALF_CONSTEXPR half operator+(half arg) { return arg; } - -/// Negation. -/// \param arg operand -/// \return negated operand -inline HALF_CONSTEXPR half operator-(half arg) { return half(detail::binary, arg.data_ ^ 0x8000); } - -/// Addition. -/// This operation is exact to rounding for all rounding modes. -/// \param x left operand -/// \param y right operand -/// \return sum of half expressions -/// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator+(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half( - detail::binary, - detail::float2half(detail::half2float(x.data_) + - detail::half2float(y.data_))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; - bool sub = ((x.data_ ^ y.data_) & 0x8000) != 0; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absy != 0x7C00) ? x.data_ - : (sub && absx == 0x7C00) ? detail::invalid() : y.data_); - if(!absx) - return absy ? y - : half(detail::binary, - (half::round_style == std::round_toward_neg_infinity) - ? (x.data_ | y.data_) - : (x.data_ & y.data_)); - if(!absy) - return x; - unsigned int sign = ((sub && absy > absx) ? y.data_ : x.data_) & 0x8000; - if(absy > absx) - std::swap(absx, absy); - int exp = (absx >> 10) + (absx <= 0x3FF), d = exp - (absy >> 10) - (absy <= 0x3FF), - mx = ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << 3, my; - if(d < 13) - { - my = ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << 3; - my = (my >> d) | ((my & ((1 << d) - 1)) != 0); - } - else - my = 1; - if(sub) - { - if(!(mx -= my)) - return half(detail::binary, - static_cast(half::round_style == std::round_toward_neg_infinity) - << 15); - for(; mx < 0x2000 && exp > 1; mx <<= 1, --exp) - ; - } - else - { - mx += my; - int i = mx >> 14; - if((exp += i) > 30) - return half(detail::binary, detail::overflow(sign)); - mx = (mx >> i) | (mx & i); - } - return half(detail::binary, - detail::rounded( - sign + ((exp - 1) << 10) + (mx >> 3), (mx >> 2) & 1, (mx & 0x3) != 0)); -#endif -} - -/// Subtraction. -/// This operation is exact to rounding for all rounding modes. 
-/// \param x left operand -/// \param y right operand -/// \return difference of half expressions -/// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator-(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half( - detail::binary, - detail::float2half(detail::half2float(x.data_) - - detail::half2float(y.data_))); -#else - return x + -y; -#endif -} - -/// Multiplication. -/// This operation is exact to rounding for all rounding modes. -/// \param x left operand -/// \param y right operand -/// \return product of half expressions -/// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator*(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half( - detail::binary, - detail::float2half(detail::half2float(x.data_) * - detail::half2float(y.data_))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; - unsigned int sign = (x.data_ ^ y.data_) & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : ((absx == 0x7C00 && !absy) || (absy == 0x7C00 && !absx)) - ? detail::invalid() - : (sign | 0x7C00)); - if(!absx || !absy) - return half(detail::binary, sign); - for(; absx < 0x400; absx <<= 1, --exp) - ; - for(; absy < 0x400; absy <<= 1, --exp) - ; - detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * - static_cast((absy & 0x3FF) | 0x400); - int i = m >> 21, s = m & i; - exp += (absx >> 10) + (absy >> 10) + i; - if(exp > 29) - return half(detail::binary, detail::overflow(sign)); - else if(exp < -11) - return half(detail::binary, detail::underflow(sign)); - return half( - detail::binary, - detail::fixed2half(m >> i, exp, sign, s)); -#endif -} - -/// Division. -/// This operation is exact to rounding for all rounding modes. -/// \param x left operand -/// \param y right operand -/// \return quotient of half expressions -/// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is -/// signaling NaN -/// \exception FE_DIVBYZERO if dividing finite value by 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator/(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half( - detail::binary, - detail::float2half(detail::half2float(x.data_) / - detail::half2float(y.data_))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; - unsigned int sign = (x.data_ ^ y.data_) & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absx == absy) ? detail::invalid() - : (sign | ((absx == 0x7C00) ? 0x7C00 : 0))); - if(!absx) - return half(detail::binary, absy ? 
sign : detail::invalid()); - if(!absy) - return half(detail::binary, detail::pole(sign)); - for(; absx < 0x400; absx <<= 1, --exp) - ; - for(; absy < 0x400; absy <<= 1, ++exp) - ; - detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; - int i = mx < my; - exp += (absx >> 10) - (absy >> 10) - i; - if(exp > 29) - return half(detail::binary, detail::overflow(sign)); - else if(exp < -11) - return half(detail::binary, detail::underflow(sign)); - mx <<= 12 + i; - my <<= 1; - return half(detail::binary, - detail::fixed2half( - mx / my, exp, sign, mx % my != 0)); -#endif -} - -/// \} -/// \anchor streaming -/// \name Input and output -/// \{ - -/// Output operator. -/// This uses the built-in functionality for streaming out floating-point numbers. -/// \param out output stream to write into -/// \param arg half expression to write -/// \return reference to output stream -template -std::basic_ostream& operator<<(std::basic_ostream& out, half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return out << detail::half2float(arg.data_); -#else - return out << detail::half2float(arg.data_); -#endif -} - -/// Input operator. -/// This uses the built-in functionality for streaming in floating-point numbers, specifically -/// double precision floating -/// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the -/// input string is first -/// rounded to double precision using the underlying platform's current floating-point rounding mode -/// before being rounded -/// to half-precision using the library's half-precision rounding mode. -/// \param in input stream to read from -/// \param arg half to read into -/// \return reference to input stream -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -template -std::basic_istream& operator>>(std::basic_istream& in, half& arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t f; -#else - double f; -#endif - if(in >> f) - arg.data_ = detail::float2half(f); - return in; -} - -/// \} -/// \anchor basic -/// \name Basic mathematical operations -/// \{ - -/// Absolute value. -/// **See also:** Documentation for -/// [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). -/// \param arg operand -/// \return absolute value of \a arg -inline HALF_CONSTEXPR half fabs(half arg) { return half(detail::binary, arg.data_ & 0x7FFF); } - -/// Absolute value. -/// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). -/// \param arg operand -/// \return absolute value of \a arg -inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); } - -/// Remainder of division. -/// **See also:** Documentation for -/// [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). -/// \param x first operand -/// \param y second operand -/// \return remainder of floating-point division. -/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN -inline half fmod(half x, half y) -{ - unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absx == 0x7C00) ? detail::invalid() : x.data_); - if(!absy) - return half(detail::binary, detail::invalid()); - if(!absx) - return x; - if(absx == absy) - return half(detail::binary, sign); - return half(detail::binary, sign | detail::mod(absx, absy)); -} - -/// Remainder of division. 
-/// **See also:** Documentation for -/// [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). -/// \param x first operand -/// \param y second operand -/// \return remainder of floating-point division. -/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN -inline half remainder(half x, half y) -{ - unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absx == 0x7C00) ? detail::invalid() : x.data_); - if(!absy) - return half(detail::binary, detail::invalid()); - if(absx == absy) - return half(detail::binary, sign); - return half(detail::binary, sign ^ detail::mod(absx, absy)); -} - -/// Remainder of division. -/// **See also:** Documentation for -/// [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). -/// \param x first operand -/// \param y second operand -/// \param quo address to store some bits of quotient at -/// \return remainder of floating-point division. -/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN -inline half remquo(half x, half y, int* quo) -{ - unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absx == 0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); - if(!absy) - return half(detail::binary, detail::invalid()); - bool qsign = ((value ^ y.data_) & 0x8000) != 0; - int q = 1; - if(absx != absy) - value ^= detail::mod(absx, absy, &q); - return *quo = qsign ? -q : q, half(detail::binary, value); -} - -/// Fused multiply add. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). -/// \param x first operand -/// \param y second operand -/// \param z third operand -/// \return ( \a x * \a y ) + \a z rounded as one operation. -/// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet -/// NaN and no argument is a signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition -inline half fma(half x, half y, half z) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t fx = detail::half2float(x.data_), - fy = detail::half2float(y.data_), - fz = detail::half2float(z.data_); -#if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA - return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); -#else - return half(detail::binary, detail::float2half(fx * fy + fz)); -#endif -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; - unsigned int sign = (x.data_ ^ y.data_) & 0x8000; - bool sub = ((sign ^ z.data_) & 0x8000) != 0; - if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) - return (absx > 0x7C00 || absy > 0x7C00 || absz > 0x7C00) - ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) - : (absx == 0x7C00) ? half(detail::binary, - (!absy || (sub && absz == 0x7C00)) ? detail::invalid() - : (sign | 0x7C00)) - : (absy == 0x7C00) ? half(detail::binary, - (!absx || (sub && absz == 0x7C00)) - ? detail::invalid() - : (sign | 0x7C00)) - : z; - if(!absx || !absy) - return absz - ? 
z - : half(detail::binary, - (half::round_style == std::round_toward_neg_infinity) ? (z.data_ | sign) - : (z.data_ & sign)); - for(; absx < 0x400; absx <<= 1, --exp) - ; - for(; absy < 0x400; absy <<= 1, --exp) - ; - detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * - static_cast((absy & 0x3FF) | 0x400); - int i = m >> 21; - exp += (absx >> 10) + (absy >> 10) + i; - m <<= 3 - i; - if(absz) - { - int expz = 0; - for(; absz < 0x400; absz <<= 1, --expz) - ; - expz += absz >> 10; - detail::uint32 mz = static_cast((absz & 0x3FF) | 0x400) << 13; - if(expz > exp || (expz == exp && mz > m)) - { - std::swap(m, mz); - std::swap(exp, expz); - if(sub) - sign = z.data_ & 0x8000; - } - int d = exp - expz; - mz = (d < 23) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; - if(sub) - { - m = m - mz; - if(!m) - return half( - detail::binary, - static_cast(half::round_style == std::round_toward_neg_infinity) - << 15); - for(; m < 0x800000; m <<= 1, --exp) - ; - } - else - { - m += mz; - i = m >> 24; - m = (m >> i) | (m & i); - exp += i; - } - } - if(exp > 30) - return half(detail::binary, detail::overflow(sign)); - else if(exp < -10) - return half(detail::binary, detail::underflow(sign)); - return half(detail::binary, - detail::fixed2half(m, exp - 1, sign)); -#endif -} - -/// Maximum of half expressions. -/// **See also:** Documentation for -/// [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). -/// \param x first operand -/// \param y second operand -/// \return maximum of operands, ignoring quiet NaNs -/// \exception FE_INVALID if \a x or \a y is signaling NaN -inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) -{ - return half(detail::binary, - (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < - (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) - ? detail::select(y.data_, x.data_) - : detail::select(x.data_, y.data_)); -} - -/// Minimum of half expressions. -/// **See also:** Documentation for -/// [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). -/// \param x first operand -/// \param y second operand -/// \return minimum of operands, ignoring quiet NaNs -/// \exception FE_INVALID if \a x or \a y is signaling NaN -inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) -{ - return half(detail::binary, - (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > - (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) - ? detail::select(y.data_, x.data_) - : detail::select(x.data_, y.data_)); -} - -/// Positive difference. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). -/// \param x first operand -/// \param y second operand -/// \return \a x - \a y or 0 if difference negative -/// \exception FE_... according to operator-(half,half) -inline half fdim(half x, half y) -{ - if(isnan(x) || isnan(y)) - return half(detail::binary, detail::signal(x.data_, y.data_)); - return (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <= - (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) - ? half(detail::binary, 0) - : (x - y); -} - -/// Get NaN value. -/// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). 
-/// \param arg string code -/// \return quiet NaN -inline half nanh(const char* arg) -{ - unsigned int value = 0x7FFF; - while(*arg) - value ^= static_cast(*arg++) & 0xFF; - return half(detail::binary, value); -} - -/// \} -/// \anchor exponential -/// \name Exponential functions -/// \{ - -/// Exponential function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). -/// \param arg function argument -/// \return e raised to \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half exp(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::exp(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) - : detail::signal(arg.data_)); - if(abs >= 0x4C80) - return half(detail::binary, - (arg.data_ & 0x8000) ? detail::underflow() - : detail::overflow()); - detail::uint32 m = detail::multiply64( - static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); - int e = (abs >> 10) + (abs <= 0x3FF), exp; - if(e < 14) - { - exp = 0; - m >>= 14 - e; - } - else - { - exp = m >> (45 - e); - m = (m << (e - 14)) & 0x7FFFFFFF; - } - return half(detail::binary, - detail::exp2_post( - detail::exp2(m, 26), exp, (arg.data_ & 0x8000) != 0)); -#endif -} - -/// Binary exponential. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). -/// \param arg function argument -/// \return 2 raised to \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half exp2(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::exp2(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) - : detail::signal(arg.data_)); - if(abs >= 0x4E40) - return half(detail::binary, - (arg.data_ & 0x8000) ? detail::underflow() - : detail::overflow()); - int e = (abs >> 10) + (abs <= 0x3FF), exp = (abs & 0x3FF) + ((abs > 0x3FF) << 10); - detail::uint32 m = detail::exp2((static_cast(exp) << (6 + e)) & 0x7FFFFFFF, 28); - exp >>= 25 - e; - if(m == 0x80000000) - { - if(arg.data_ & 0x8000) - exp = -exp; - else if(exp > 15) - return half(detail::binary, detail::overflow()); - return half(detail::binary, - detail::fixed2half(m, exp + 14)); - } - return half(detail::binary, - detail::exp2_post(m, exp, (arg.data_ & 0x8000) != 0)); -#endif -} - -/// Exponential minus one. -/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for -/// `std::round_to_nearest` -/// and in <1% of inputs for any other rounding mode. -/// -/// **See also:** Documentation for -/// [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). 
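A few sample calls to the exponential functions above, continuing the illustrative setup from the first sketch (quoted values are approximate half-precision results):

    // assumes #include "half.hpp" and using half_float::half; as in the first sketch
    half e1 = exp(half(1.0f));       // ~2.718
    half p  = exp2(half(10.0f));     // exactly 1024
    half t  = expm1(half(0.001f));   // ~0.001; avoids the cancellation in exp(x) - 1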
-/// \param arg function argument -/// \return e raised to \a arg and subtracted by 1 -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half expm1(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::expm1(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? (0x7C00 + (sign >> 1)) : detail::signal(arg.data_)); - if(abs >= 0x4A00) - return half(detail::binary, - (arg.data_ & 0x8000) ? detail::rounded(0xBBFF, 1, 1) - : detail::overflow()); - detail::uint32 m = detail::multiply64( - static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); - int e = (abs >> 10) + (abs <= 0x3FF), exp; - if(e < 14) - { - exp = 0; - m >>= 14 - e; - } - else - { - exp = m >> (45 - e); - m = (m << (e - 14)) & 0x7FFFFFFF; - } - m = detail::exp2(m); - if(sign) - { - int s = 0; - if(m > 0x80000000) - { - ++exp; - m = detail::divide64(0x80000000, m, s); - } - m = 0x80000000 - - ((m >> exp) | ((m & ((static_cast(1) << exp) - 1)) != 0) | s); - exp = 0; - } - else - m -= (exp < 31) ? (0x80000000 >> exp) : 1; - for(exp += 14; m < 0x80000000 && exp; m <<= 1, --exp) - ; - if(exp > 29) - return half(detail::binary, detail::overflow()); - return half(detail::binary, - detail::rounded( - sign + (exp << 10) + (m >> 21), (m >> 20) & 1, (m & 0xFFFFF) != 0)); -#endif -} - -/// Natural logarithm. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). -/// \param arg function argument -/// \return logarithm of \a arg to base e -/// \exception FE_INVALID for signaling NaN or negative argument -/// \exception FE_DIVBYZERO for 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half log(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::log(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = -15; - if(!abs) - return half(detail::binary, detail::pole(0x8000)); - if(arg.data_ & 0x8000) - return half(detail::binary, - (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs >= 0x7C00) - return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - return half(detail::binary, - detail::log2_post( - detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, - exp, - 17)); -#endif -} - -/// Common logarithm. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). -/// \param arg function argument -/// \return logarithm of \a arg to base 10 -/// \exception FE_INVALID for signaling NaN or negative argument -/// \exception FE_DIVBYZERO for 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half log10(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::log10(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = -15; - if(!abs) - return half(detail::binary, detail::pole(0x8000)); - if(arg.data_ & 0x8000) - return half(detail::binary, - (arg.data_ <= 0xFC00) ? 
detail::invalid() : detail::signal(arg.data_)); - if(abs >= 0x7C00) - return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); - switch(abs) - { - case 0x4900: return half(detail::binary, 0x3C00); - case 0x5640: return half(detail::binary, 0x4000); - case 0x63D0: return half(detail::binary, 0x4200); - case 0x70E2: return half(detail::binary, 0x4400); - } - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - return half(detail::binary, - detail::log2_post( - detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, - exp, - 16)); -#endif -} - -/// Binary logarithm. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). -/// \param arg function argument -/// \return logarithm of \a arg to base 2 -/// \exception FE_INVALID for signaling NaN or negative argument -/// \exception FE_DIVBYZERO for 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half log2(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::log2(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; - if(!abs) - return half(detail::binary, detail::pole(0x8000)); - if(arg.data_ & 0x8000) - return half(detail::binary, - (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs >= 0x7C00) - return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); - if(abs == 0x3C00) - return half(detail::binary, 0); - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += (abs >> 10); - if(!(abs & 0x3FF)) - { - unsigned int value = static_cast(exp < 0) << 15, m = std::abs(exp) << 6; - for(exp = 18; m < 0x400; m <<= 1, --exp) - ; - return half(detail::binary, value + (exp << 10) + m); - } - detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), - m = (((ilog << 27) + - (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, - 28) >> - 4)) ^ - sign) - - sign; - if(!m) - return half(detail::binary, 0); - for(exp = 14; m < 0x8000000 && exp; m <<= 1, --exp) - ; - for(; m > 0xFFFFFFF; m >>= 1, ++exp) - s |= m & 1; - return half( - detail::binary, - detail::fixed2half(m, exp, sign & 0x8000, s)); -#endif -} - -/// Natural logarithm plus one. -/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for -/// `std::round_to_nearest` -/// and in ~1% of inputs for any other rounding mode. -/// -/// **See also:** Documentation for -/// [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). -/// \param arg function argument -/// \return logarithm of \a arg plus 1 to base e -/// \exception FE_INVALID for signaling NaN or argument <-1 -/// \exception FE_DIVBYZERO for -1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half log1p(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::log1p(detail::half2float(arg.data_)))); -#else - if(arg.data_ >= 0xBC00) - return half(detail::binary, - (arg.data_ == 0xBC00) - ? detail::pole(0x8000) - : (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); - int abs = arg.data_ & 0x7FFF, exp = -15; - if(!abs || abs >= 0x7C00) - return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) : arg; - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - detail::uint32 m = static_cast((abs & 0x3FF) | 0x400) << 20; - if(arg.data_ & 0x8000) - { - m = 0x40000000 - (m >> -exp); - for(exp = 0; m < 0x40000000; m <<= 1, --exp) - ; - } - else - { - if(exp < 0) - { - m = 0x40000000 + (m >> -exp); - exp = 0; - } - else - { - m += 0x40000000 >> exp; - int i = m >> 31; - m >>= i; - exp += i; - } - } - return half(detail::binary, - detail::log2_post(detail::log2(m), exp, 17)); -#endif -} - -/// \} -/// \anchor power -/// \name Power functions -/// \{ - -/// Square root. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). -/// \param arg function argument -/// \return square root of \a arg -/// \exception FE_INVALID for signaling NaN and negative arguments -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half sqrt(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::sqrt(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = 15; - if(!abs || arg.data_ >= 0x7C00) - return half(detail::binary, - (abs > 0x7C00) ? detail::signal(arg.data_) - : (arg.data_ > 0x8000) ? detail::invalid() : arg.data_); - for(; abs < 0x400; abs <<= 1, --exp) - ; - detail::uint32 r = static_cast((abs & 0x3FF) | 0x400) << 10, - m = detail::sqrt<20>(r, exp += abs >> 10); - return half( - detail::binary, - detail::rounded((exp << 10) + (m & 0x3FF), r > m, r != 0)); -#endif -} - -/// Cubic root. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). -/// \param arg function argument -/// \return cubic root of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half cbrt(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::cbrt(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = -15; - if(!abs || abs == 0x3C00 || abs >= 0x7C00) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - for(; abs < 0x400; abs <<= 1, --exp) - ; - detail::uint32 ilog = exp + (abs >> 10), sign = detail::sign_mask(ilog), f, - m = (((ilog << 27) + - (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, - 24) >> - 4)) ^ - sign) - - sign; - for(exp = 2; m < 0x80000000; m <<= 1, --exp) - ; - m = detail::multiply64(m, 0xAAAAAAAB); - int i = m >> 31, s; - exp += i; - m <<= 1 - i; - if(exp < 0) - { - f = m >> -exp; - exp = 0; - } - else - { - f = (m << exp) & 0x7FFFFFFF; - exp = m >> (31 - exp); - } - m = detail::exp2(f, (half::round_style == std::round_to_nearest) ? 29 : 26); - if(sign) - { - if(m > 0x80000000) - { - m = detail::divide64(0x80000000, m, s); - ++exp; - } - exp = -exp; - } - return half(detail::binary, - (half::round_style == std::round_to_nearest) - ? detail::fixed2half( - m, exp + 14, arg.data_ & 0x8000) - : detail::fixed2half( - (m + 0x80) >> 8, exp + 14, arg.data_ & 0x8000)); -#endif -} - -/// Hypotenuse function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). 
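A small illustration of the logarithm and root functions above (same assumed setup as the first sketch; results quoted to half precision):

    // assumes the include/using setup from the first sketch
    half l = log2(half(8.0f));       // exactly 3
    half s = sqrt(half(2.0f));       // ~1.414, exact to rounding
    half c = cbrt(half(27.0f));      // exactly 3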
-/// \param x first argument -/// \param y second argument -/// \return square root of sum of squares without internal over- or underflows -/// \exception FE_INVALID if \a x or \a y is signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root -inline half hypot(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t fx = detail::half2float(x.data_), - fy = detail::half2float(y.data_); -#if HALF_ENABLE_CPP11_CMATH - return half(detail::binary, detail::float2half(std::hypot(fx, fy))); -#else - return half(detail::binary, - detail::float2half(std::sqrt(fx * fx + fy * fy))); -#endif -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx == 0x7C00) ? detail::select(0x7C00, y.data_) - : (absy == 0x7C00) ? detail::select(0x7C00, x.data_) - : detail::signal(x.data_, y.data_)); - if(!absx) - return half(detail::binary, absy ? detail::check_underflow(absy) : 0); - if(!absy) - return half(detail::binary, detail::check_underflow(absx)); - if(absy > absx) - std::swap(absx, absy); - for(; absx < 0x400; absx <<= 1, --expx) - ; - for(; absy < 0x400; absy <<= 1, --expy) - ; - detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; - mx *= mx; - my *= my; - int ix = mx >> 21, iy = my >> 21; - expx = 2 * (expx + (absx >> 10)) - 15 + ix; - expy = 2 * (expy + (absy >> 10)) - 15 + iy; - mx <<= 10 - ix; - my <<= 10 - iy; - int d = expx - expy; - my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; - return half(detail::binary, detail::hypot_post(mx + my, expx)); -#endif -} - -/// Hypotenuse function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). -/// \param x first argument -/// \param y second argument -/// \param z third argument -/// \return square root of sum of squares without internal over- or underflows -/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root -inline half hypot(half x, half y, half z) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t fx = detail::half2float(x.data_), - fy = detail::half2float(y.data_), - fz = detail::half2float(z.data_); - return half(detail::binary, - detail::float2half(std::sqrt(fx * fx + fy * fy + fz * fz))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, - expy = 0, expz = 0; - if(!absx) - return hypot(y, z); - if(!absy) - return hypot(x, z); - if(!absz) - return hypot(x, y); - if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) - return half(detail::binary, - (absx == 0x7C00) - ? detail::select(0x7C00, detail::select(y.data_, z.data_)) - : (absy == 0x7C00) - ? detail::select(0x7C00, detail::select(x.data_, z.data_)) - : (absz == 0x7C00) - ? 
detail::select(0x7C00, detail::select(x.data_, y.data_)) - : detail::signal(x.data_, y.data_, z.data_)); - if(absz > absy) - std::swap(absy, absz); - if(absy > absx) - std::swap(absx, absy); - if(absz > absy) - std::swap(absy, absz); - for(; absx < 0x400; absx <<= 1, --expx) - ; - for(; absy < 0x400; absy <<= 1, --expy) - ; - for(; absz < 0x400; absz <<= 1, --expz) - ; - detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400, - mz = (absz & 0x3FF) | 0x400; - mx *= mx; - my *= my; - mz *= mz; - int ix = mx >> 21, iy = my >> 21, iz = mz >> 21; - expx = 2 * (expx + (absx >> 10)) - 15 + ix; - expy = 2 * (expy + (absy >> 10)) - 15 + iy; - expz = 2 * (expz + (absz >> 10)) - 15 + iz; - mx <<= 10 - ix; - my <<= 10 - iy; - mz <<= 10 - iz; - int d = expy - expz; - mz = (d < 30) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; - my += mz; - if(my & 0x80000000) - { - my = (my >> 1) | (my & 1); - if(++expy > expx) - { - std::swap(mx, my); - std::swap(expx, expy); - } - } - d = expx - expy; - my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; - return half(detail::binary, detail::hypot_post(mx + my, expx)); -#endif -} - -/// Power function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in -/// ~0.00025% of inputs. -/// -/// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow). -/// \param x base -/// \param y exponent -/// \return \a x raised to \a y -/// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite an negative and \a y -/// is finite and not integral -/// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half pow(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::pow(detail::half2float(x.data_), - detail::half2float(y.data_)))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15; - if(!absy || x.data_ == 0x3C00) - return half(detail::binary, - detail::select(0x3C00, (x.data_ == 0x3C00) ? y.data_ : x.data_)); - bool is_int = absy >= 0x6400 || (absy >= 0x3C00 && !(absy & ((1 << (25 - (absy >> 10))) - 1))); - unsigned int sign = - x.data_ & - (static_cast((absy < 0x6800) && is_int && ((absy >> (25 - (absy >> 10))) & 1)) - << 15); - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absy == 0x7C00) - ? ((absx == 0x3C00) - ? 0x3C00 - : (!absx && y.data_ == 0xFC00) - ? detail::pole() - : (0x7C00 & -((y.data_ >> 15) ^ (absx > 0x3C00)))) - : (sign | (0x7C00 & ((y.data_ >> 15) - 1U)))); - if(!absx) - return half(detail::binary, (y.data_ & 0x8000) ? 
detail::pole(sign) : sign); - if((x.data_ & 0x8000) && !is_int) - return half(detail::binary, detail::invalid()); - if(x.data_ == 0xBC00) - return half(detail::binary, sign | 0x3C00); - if(y.data_ == 0x3800) - return sqrt(x); - if(y.data_ == 0x3C00) - return half(detail::binary, detail::check_underflow(x.data_)); - if(y.data_ == 0x4000) - return x * x; - for(; absx < 0x400; absx <<= 1, --exp) - ; - detail::uint32 ilog = exp + (absx >> 10), msign = detail::sign_mask(ilog), f, - m = (((ilog << 27) + - ((detail::log2(static_cast((absx & 0x3FF) | 0x400) << 20) + - 8) >> - 4)) ^ - msign) - - msign; - for(exp = -11; m < 0x80000000; m <<= 1, --exp) - ; - for(; absy < 0x400; absy <<= 1, --exp) - ; - m = detail::multiply64(m, static_cast((absy & 0x3FF) | 0x400) << 21); - int i = m >> 31; - exp += (absy >> 10) + i; - m <<= 1 - i; - if(exp < 0) - { - f = m >> -exp; - exp = 0; - } - else - { - f = (m << exp) & 0x7FFFFFFF; - exp = m >> (31 - exp); - } - return half(detail::binary, - detail::exp2_post( - detail::exp2(f), exp, ((msign & 1) ^ (y.data_ >> 15)) != 0, sign)); -#endif -} - -/// \} -/// \anchor trigonometric -/// \name Trigonometric functions -/// \{ - -/// Compute sine and cosine simultaneously. -/// This returns the same results as sin() and cos() but is faster than calling each function -/// individually. -/// -/// This function is exact to rounding for all rounding modes. -/// \param arg function argument -/// \param sin variable to take sine of \a arg -/// \param cos variable to take cosine of \a arg -/// \exception FE_INVALID for signaling NaN or infinity -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline void sincos(half arg, half* sin, half* cos) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t f = detail::half2float(arg.data_); - *sin = half(detail::binary, detail::float2half(std::sin(f))); - *cos = half(detail::binary, detail::float2half(std::cos(f))); -#else - int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; - if(abs >= 0x7C00) - *sin = *cos = - half(detail::binary, (abs == 0x7C00) ? 
detail::invalid() : detail::signal(arg.data_)); - else if(!abs) - { - *sin = arg; - *cos = half(detail::binary, 0x3C00); - } - else if(abs < 0x2500) - { - *sin = half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); - } - else - { - if(half::round_style != std::round_to_nearest) - { - switch(abs) - { - case 0x48B7: - *sin = half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); - *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); - return; - case 0x598C: - *sin = half( - detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); - *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); - return; - case 0x6A64: - *sin = half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); - *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); - return; - case 0x6D8C: - *sin = half( - detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); - *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); - return; - } - } - std::pair sc = - detail::sincos(detail::angle_arg(abs, k), 28); - switch(k & 3) - { - case 1: sc = std::make_pair(sc.second, -sc.first); break; - case 2: sc = std::make_pair(-sc.first, -sc.second); break; - case 3: sc = std::make_pair(-sc.second, sc.first); break; - } - *sin = half(detail::binary, - detail::fixed2half( - (sc.first ^ -static_cast(sign)) + sign)); - *cos = half(detail::binary, - detail::fixed2half(sc.second)); - } -#endif -} - -/// Sine function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). -/// \param arg function argument -/// \return sine value of \a arg -/// \exception FE_INVALID for signaling NaN or infinity -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half sin(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::sin(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, k; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs < 0x2900) - return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - if(half::round_style != std::round_to_nearest) - switch(abs) - { - case 0x48B7: - return half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); - case 0x6A64: - return half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); - case 0x6D8C: - return half( - detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); - } - std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); - detail::uint32 sign = -static_cast(((k >> 1) & 1) ^ (arg.data_ >> 15)); - return half(detail::binary, - detail::fixed2half( - (((k & 1) ? sc.second : sc.first) ^ sign) - sign)); -#endif -} - -/// Cosine function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). 
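A short sketch of the power and trigonometric helpers above (assumed setup as before); sincos() produces both results from a single argument reduction:

    // assumes the include/using setup from the first sketch
    half h = hypot(half(3.0f), half(4.0f));   // 5, without intermediate overflow
    half p = pow(half(2.0f), half(0.5f));     // dispatches to sqrt(), per the special case above
    half s, c;
    sincos(half(0.5f), &s, &c);               // s ~0.479, c ~0.878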
-/// \param arg function argument -/// \return cosine value of \a arg -/// \exception FE_INVALID for signaling NaN or infinity -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half cos(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::cos(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, k; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs < 0x2500) - return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); - if(half::round_style != std::round_to_nearest && abs == 0x598C) - return half(detail::binary, detail::rounded(0x80FC, 1, 1)); - std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); - detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); - return half(detail::binary, - detail::fixed2half( - (((k & 1) ? sc.first : sc.second) ^ sign) - sign)); -#endif -} - -/// Tangent function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). -/// \param arg function argument -/// \return tangent value of \a arg -/// \exception FE_INVALID for signaling NaN or infinity -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half tan(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::tan(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = 13, k; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs < 0x2700) - return half(detail::binary, detail::rounded(arg.data_, 0, 1)); - if(half::round_style != std::round_to_nearest) - switch(abs) - { - case 0x658C: - return half( - detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x07E6, 1, 1)); - case 0x7330: - return half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x4B62, 1, 1)); - } - std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); - if(k & 1) - sc = std::make_pair(-sc.second, sc.first); - detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); - detail::uint32 my = (sc.first ^ signy) - signy, mx = (sc.second ^ signx) - signx; - for(; my < 0x80000000; my <<= 1, --exp) - ; - for(; mx < 0x80000000; mx <<= 1, ++exp) - ; - return half( - detail::binary, - detail::tangent_post(my, mx, exp, (signy ^ signx ^ arg.data_) & 0x8000)); -#endif -} - -/// Arc sine. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). -/// \param arg function argument -/// \return arc sine value of \a arg -/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half asin(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::asin(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(!abs) - return arg; - if(abs >= 0x3C00) - return half(detail::binary, - (abs > 0x7C00) - ? detail::signal(arg.data_) - : (abs > 0x3C00) - ? 
detail::invalid() - : detail::rounded(sign | 0x3E48, 0, 1)); - if(abs < 0x2900) - return half(detail::binary, detail::rounded(arg.data_, 0, 1)); - if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) - return half(detail::binary, detail::rounded(arg.data_ + 1, 1, 1)); - std::pair sc = detail::atan2_args(abs); - detail::uint32 m = - detail::atan2(sc.first, sc.second, (half::round_style == std::round_to_nearest) ? 27 : 26); - return half(detail::binary, - detail::fixed2half(m, 14, sign)); -#endif -} - -/// Arc cosine function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). -/// \param arg function argument -/// \return arc cosine value of \a arg -/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half acos(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::acos(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; - if(!abs) - return half(detail::binary, detail::rounded(0x3E48, 0, 1)); - if(abs >= 0x3C00) - return half(detail::binary, - (abs > 0x7C00) - ? detail::signal(arg.data_) - : (abs > 0x3C00) - ? detail::invalid() - : sign ? detail::rounded(0x4248, 0, 1) : 0); - std::pair cs = detail::atan2_args(abs); - detail::uint32 m = detail::atan2(cs.second, cs.first, 28); - return half(detail::binary, - detail::fixed2half( - sign ? (0xC90FDAA2 - m) : m, 15, 0, sign)); -#endif -} - -/// Arc tangent function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). -/// \param arg function argument -/// \return arc tangent value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half atan(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::atan(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? detail::rounded(sign | 0x3E48, 0, 1) - : detail::signal(arg.data_)); - if(abs <= 0x2700) - return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - int exp = (abs >> 10) + (abs <= 0x3FF); - detail::uint32 my = (abs & 0x3FF) | ((abs > 0x3FF) << 10); - detail::uint32 m = (exp > 15) - ? detail::atan2(my << 19, - 0x20000000 >> (exp - 15), - (half::round_style == std::round_to_nearest) ? 26 : 24) - : detail::atan2(my << (exp + 4), - 0x20000000, - (half::round_style == std::round_to_nearest) ? 30 : 28); - return half(detail::binary, - detail::fixed2half(m, 14, sign)); -#endif -} - -/// Arc tangent function. -/// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for -/// `std::round_to_nearest`, -/// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding -/// mode. -/// -/// **See also:** Documentation for -/// [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). 
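For the inverse trigonometric functions above, a minimal sketch under the same assumed setup (values approximate):

    // assumes the include/using setup from the first sketch
    half x(0.5f);
    half a = asin(x);                // ~0.524 (pi/6)
    half b = acos(x);                // ~1.047 (pi/3)
    half t = atan(half(1.0f));       // ~0.785 (pi/4)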
-/// \param y numerator -/// \param x denominator -/// \return arc tangent value -/// \exception FE_INVALID if \a x or \a y is signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half atan2(half y, half x) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::atan2(detail::half2float(y.data_), - detail::half2float(x.data_)))); -#else - unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, - signy = y.data_ & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - { - if(absx > 0x7C00 || absy > 0x7C00) - return half(detail::binary, detail::signal(x.data_, y.data_)); - if(absy == 0x7C00) - return half(detail::binary, - (absx < 0x7C00) - ? detail::rounded(signy | 0x3E48, 0, 1) - : signx - ? detail::rounded(signy | 0x40B6, 0, 1) - : detail::rounded(signy | 0x3A48, 0, 1)); - return (x.data_ == 0x7C00) - ? half(detail::binary, signy) - : half(detail::binary, - detail::rounded(signy | 0x4248, 0, 1)); - } - if(!absy) - return signx ? half(detail::binary, - detail::rounded(signy | 0x4248, 0, 1)) - : y; - if(!absx) - return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); - int d = (absy >> 10) + (absy <= 0x3FF) - (absx >> 10) - (absx <= 0x3FF); - if(d > (signx ? 18 : 12)) - return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); - if(signx && d < -11) - return half(detail::binary, detail::rounded(signy | 0x4248, 0, 1)); - if(!signx && d < ((half::round_style == std::round_toward_zero) ? -15 : -9)) - { - for(; absy < 0x400; absy <<= 1, --d) - ; - detail::uint32 mx = ((absx << 1) & 0x7FF) | 0x800, my = ((absy << 1) & 0x7FF) | 0x800; - int i = my < mx; - d -= i; - if(d < -25) - return half(detail::binary, detail::underflow(signy)); - my <<= 11 + i; - return half(detail::binary, - detail::fixed2half( - my / mx, d + 14, signy, my % mx != 0)); - } - detail::uint32 m = detail::atan2( - ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << (19 + ((d < 0) ? d : (d > 0) ? 0 : -1)), - ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << (19 - ((d > 0) ? d : (d < 0) ? 0 : 1))); - return half(detail::binary, - detail::fixed2half( - signx ? (0xC90FDAA2 - m) : m, 15, signy, signx)); -#endif -} - -/// \} -/// \anchor hyperbolic -/// \name Hyperbolic functions -/// \{ - -/// Hyperbolic sine. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). -/// \param arg function argument -/// \return hyperbolic sine value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half sinh(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::sinh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp; - if(!abs || abs >= 0x7C00) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - if(abs <= 0x2900) - return half(detail::binary, detail::rounded(arg.data_, 0, 1)); - std::pair mm = - detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 29 : 27); - detail::uint32 m = mm.first - mm.second; - for(exp += 13; m < 0x80000000 && exp; m <<= 1, --exp) - ; - unsigned int sign = arg.data_ & 0x8000; - if(exp > 29) - return half(detail::binary, detail::overflow(sign)); - return half(detail::binary, - detail::fixed2half(m, exp, sign)); -#endif -} - -/// Hyperbolic cosine. 
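A quick illustration of the quadrant-aware atan2() above (assumed setup as before; values approximate):

    // assumes the include/using setup from the first sketch
    half q2 = atan2(half(1.0f), half(-1.0f)); // ~2.356 (3*pi/4), quadrant preserved
    half q4 = atan2(half(-1.0f), half(1.0f)); // ~-0.785 (-pi/4)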
-/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). -/// \param arg function argument -/// \return hyperbolic cosine value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half cosh(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::cosh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x7C00) - return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) : 0x7C00); - std::pair mm = - detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 23 : 26); - detail::uint32 m = mm.first + mm.second, i = (~m & 0xFFFFFFFF) >> 31; - m = (m >> i) | (m & i) | 0x80000000; - if((exp += 13 + i) > 29) - return half(detail::binary, detail::overflow()); - return half(detail::binary, - detail::fixed2half(m, exp)); -#endif -} - -/// Hyperbolic tangent. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). -/// \param arg function argument -/// \return hyperbolic tangent value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half tanh(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::tanh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs > 0x7C00) ? detail::signal(arg.data_) : (arg.data_ - 0x4000)); - if(abs >= 0x4500) - return half(detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); - if(abs < 0x2700) - return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - if(half::round_style != std::round_to_nearest && abs == 0x2D3F) - return half(detail::binary, detail::rounded(arg.data_ - 3, 0, 1)); - std::pair mm = detail::hyperbolic_args(abs, exp, 27); - detail::uint32 my = mm.first - mm.second - (half::round_style != std::round_to_nearest), - mx = mm.first + mm.second, i = (~mx & 0xFFFFFFFF) >> 31; - for(exp = 13; my < 0x80000000; my <<= 1, --exp) - ; - mx = (mx >> i) | 0x80000000; - return half(detail::binary, - detail::tangent_post(my, mx, exp - i, arg.data_ & 0x8000)); -#endif -} - -/// Hyperbolic area sine. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). -/// \param arg function argument -/// \return area sine value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half asinh(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::asinh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if(!abs || abs >= 0x7C00) - return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) : arg; - if(abs <= 0x2900) - return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - if(half::round_style != std::round_to_nearest) - switch(abs) - { - case 0x32D4: - return half(detail::binary, - detail::rounded(arg.data_ - 13, 1, 1)); - case 0x3B5B: - return half(detail::binary, - detail::rounded(arg.data_ - 197, 1, 1)); - } - return half(detail::binary, detail::area(arg.data_)); -#endif -} - -/// Hyperbolic area cosine. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). -/// \param arg function argument -/// \return area cosine value of \a arg -/// \exception FE_INVALID for signaling NaN or arguments <1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half acosh(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::acosh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if((arg.data_ & 0x8000) || abs < 0x3C00) - return half(detail::binary, - (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs == 0x3C00) - return half(detail::binary, 0); - if(arg.data_ >= 0x7C00) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - return half(detail::binary, detail::area(arg.data_)); -#endif -} - -/// Hyperbolic area tangent. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). -/// \param arg function argument -/// \return area tangent value of \a arg -/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 -/// \exception FE_DIVBYZERO for +/-1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half atanh(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::atanh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = 0; - if(!abs) - return arg; - if(abs >= 0x3C00) - return half(detail::binary, - (abs == 0x3C00) - ? detail::pole(arg.data_ & 0x8000) - : (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs < 0x2700) - return half(detail::binary, detail::rounded(arg.data_, 0, 1)); - detail::uint32 m = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) - << ((abs >> 10) + (abs <= 0x3FF) + 6), - my = 0x80000000 + m, mx = 0x80000000 - m; - for(; mx < 0x80000000; mx <<= 1, ++exp) - ; - int i = my >= mx, s; - return half(detail::binary, - detail::log2_post( - detail::log2((detail::divide64(my >> i, mx, s) + 1) >> 1, 27) + 0x10, - exp + i - 1, - 16, - arg.data_ & 0x8000)); -#endif -} - -/// \} -/// \anchor special -/// \name Error and gamma functions -/// \{ - -/// Error function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% -/// of inputs. -/// -/// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). 
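The hyperbolic functions above in a minimal sketch (assumed setup as before; values approximate):

    // assumes the include/using setup from the first sketch
    half s = sinh(half(1.0f));       // ~1.175
    half c = cosh(half(1.0f));       // ~1.543
    half t = tanh(half(10.0f));      // saturates just below 1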
-/// \param arg function argument -/// \return error function value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half erf(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::erf(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF; - if(!abs || abs >= 0x7C00) - return (abs >= 0x7C00) - ? half(detail::binary, - (abs == 0x7C00) ? (arg.data_ - 0x4000) : detail::signal(arg.data_)) - : arg; - if(abs >= 0x4200) - return half(detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); - return half(detail::binary, detail::erf(arg.data_)); -#endif -} - -/// Complementary error function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% -/// of inputs. -/// -/// **See also:** Documentation for -/// [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). -/// \param arg function argument -/// \return 1 minus error function value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half erfc(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::erfc(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(abs >= 0x7C00) - return (abs >= 0x7C00) - ? half(detail::binary, (abs == 0x7C00) ? (sign >> 1) : detail::signal(arg.data_)) - : arg; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x4400) - return half( - detail::binary, - detail::rounded((sign >> 1) - (sign >> 15), sign >> 15, 1)); - return half(detail::binary, detail::erf(arg.data_)); -#endif -} - -/// Natural logarithm of gamma function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in -/// ~0.025% of inputs. -/// -/// **See also:** Documentation for -/// [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). -/// \param arg function argument -/// \return natural logarith of gamma function for \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_DIVBYZERO for 0 or negative integer arguments -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half lgamma(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::lgamma(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if(abs >= 0x7C00) - return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); - if(!abs || arg.data_ >= 0xE400 || - (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) - return half(detail::binary, detail::pole()); - if(arg.data_ == 0x3C00 || arg.data_ == 0x4000) - return half(detail::binary, 0); - return half(detail::binary, detail::gamma(arg.data_)); -#endif -} - -/// Gamma function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in -/// <0.25% of inputs. -/// -/// **See also:** Documentation for -/// [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). 
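A small sketch of the error and gamma functions above (assumed setup as before; values approximate):

    // assumes the include/using setup from the first sketch
    half e = erf(half(1.0f));        // ~0.843
    half g = tgamma(half(5.0f));     // exactly 24 (4!)
    half l = lgamma(half(5.0f));     // ~3.178, the natural log of 24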
-/// \param arg function argument -/// \return gamma function value of \a arg -/// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments -/// \exception FE_DIVBYZERO for 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half tgamma(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::tgamma(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF; - if(!abs) - return half(detail::binary, detail::pole(arg.data_)); - if(abs >= 0x7C00) - return (arg.data_ == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); - if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) - return half(detail::binary, detail::invalid()); - if(arg.data_ >= 0xCA80) - return half( - detail::binary, - detail::underflow((1 - ((abs >> (25 - (abs >> 10))) & 1)) << 15)); - if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) - return half(detail::binary, detail::overflow()); - if(arg.data_ == 0x3C00) - return arg; - return half(detail::binary, detail::gamma(arg.data_)); -#endif -} - -/// \} -/// \anchor rounding -/// \name Rounding -/// \{ - -/// Nearest integer not less than half value. -/// **See also:** Documentation for -/// [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). -/// \param arg half to round -/// \return nearest integer not less than \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half ceil(half arg) -{ - return half(detail::binary, - detail::integral(arg.data_)); -} - -/// Nearest integer not greater than half value. -/// **See also:** Documentation for -/// [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). -/// \param arg half to round -/// \return nearest integer not greater than \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half floor(half arg) -{ - return half(detail::binary, - detail::integral(arg.data_)); -} - -/// Nearest integer not greater in magnitude than half value. -/// **See also:** Documentation for -/// [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). -/// \param arg half to round -/// \return nearest integer not greater in magnitude than \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half trunc(half arg) -{ - return half(detail::binary, detail::integral(arg.data_)); -} - -/// Nearest integer. -/// **See also:** Documentation for -/// [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). -/// \param arg half to round -/// \return nearest integer, rounded away from zero in half-way cases -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half round(half arg) -{ - return half(detail::binary, detail::integral(arg.data_)); -} - -/// Nearest integer. -/// **See also:** Documentation for -/// [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). -/// \param arg half to round -/// \return nearest integer, rounded away from zero in half-way cases -/// \exception FE_INVALID if value is not representable as `long` -inline long lround(half arg) -{ - return detail::half2int(arg.data_); -} - -/// Nearest integer using half's internal rounding mode. 
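The rounding functions above, illustrated on an exactly representable half-way case (assumed setup as before):

    // assumes the include/using setup from the first sketch
    half x(2.5f);
    half u = ceil(x);                // 3
    half d = floor(x);               // 2
    half r = round(x);               // 3: half-way cases round away from zero
    long n = lround(half(-2.5f));    // -3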
-/// **See also:** Documentation for -/// [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). -/// \param arg half expression to round -/// \return nearest integer using default rounding mode -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half rint(half arg) -{ - return half(detail::binary, detail::integral(arg.data_)); -} - -/// Nearest integer using half's internal rounding mode. -/// **See also:** Documentation for -/// [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). -/// \param arg half expression to round -/// \return nearest integer using default rounding mode -/// \exception FE_INVALID if value is not representable as `long` -/// \exception FE_INEXACT if value had to be rounded -inline long lrint(half arg) -{ - return detail::half2int(arg.data_); -} - -/// Nearest integer using half's internal rounding mode. -/// **See also:** Documentation for -/// [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). -/// \param arg half expression to round -/// \return nearest integer using default rounding mode -/// \exception FE_INVALID for signaling NaN -inline half nearbyint(half arg) -{ - return half(detail::binary, detail::integral(arg.data_)); -} -#if HALF_ENABLE_CPP11_LONG_LONG -/// Nearest integer. -/// **See also:** Documentation for -/// [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). -/// \param arg half to round -/// \return nearest integer, rounded away from zero in half-way cases -/// \exception FE_INVALID if value is not representable as `long long` -inline long long llround(half arg) -{ - return detail::half2int(arg.data_); -} - -/// Nearest integer using half's internal rounding mode. -/// **See also:** Documentation for -/// [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). -/// \param arg half expression to round -/// \return nearest integer using default rounding mode -/// \exception FE_INVALID if value is not representable as `long long` -/// \exception FE_INEXACT if value had to be rounded -inline long long llrint(half arg) -{ - return detail::half2int(arg.data_); -} -#endif - -/// \} -/// \anchor float -/// \name Floating point manipulation -/// \{ - -/// Decompress floating-point number. -/// **See also:** Documentation for -/// [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). -/// \param arg number to decompress -/// \param exp address to store exponent at -/// \return significant in range [0.5, 1) -/// \exception FE_INVALID for signaling NaN -inline half frexp(half arg, int* exp) -{ - *exp = 0; - unsigned int abs = arg.data_ & 0x7FFF; - if(abs >= 0x7C00 || !abs) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - for(; abs < 0x400; abs <<= 1, --*exp) - ; - *exp += (abs >> 10) - 14; - return half(detail::binary, (arg.data_ & 0x8000) | 0x3800 | (abs & 0x3FF)); -} - -/// Multiply by power of two. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). 
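A minimal example of frexp() as defined above, decomposing a half into significand and exponent (assumed setup as before):

    // assumes the include/using setup from the first sketch
    int e = 0;
    half m = frexp(half(48.0f), &e); // m = 0.75, e = 6, since 48 = 0.75 * 2^6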
-/// \param arg number to modify -/// \param exp power of two to multiply with -/// \return \a arg multplied by 2 raised to \a exp -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half scalbln(half arg, long exp) -{ - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(abs >= 0x7C00 || !abs) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - if(exp > 30) - return half(detail::binary, detail::overflow(sign)); - else if(exp < -10) - return half(detail::binary, detail::underflow(sign)); - else if(exp > 0) - return half(detail::binary, sign | (exp << 10) | (abs & 0x3FF)); - unsigned int m = (abs & 0x3FF) | 0x400; - return half(detail::binary, - detail::rounded( - sign | (m >> (1 - exp)), (m >> -exp) & 1, (m & ((1 << -exp) - 1)) != 0)); -} - -/// Multiply by power of two. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). -/// \param arg number to modify -/// \param exp power of two to multiply with -/// \return \a arg multplied by 2 raised to \a exp -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } - -/// Multiply by power of two. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). -/// \param arg number to modify -/// \param exp power of two to multiply with -/// \return \a arg multplied by 2 raised to \a exp -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } - -/// Extract integer and fractional parts. -/// **See also:** Documentation for -/// [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). -/// \param arg number to decompress -/// \param iptr address to store integer part at -/// \return fractional part -/// \exception FE_INVALID for signaling NaN -inline half modf(half arg, half* iptr) -{ - unsigned int abs = arg.data_ & 0x7FFF; - if(abs > 0x7C00) - { - arg = half(detail::binary, detail::signal(arg.data_)); - return *iptr = arg, arg; - } - if(abs >= 0x6400) - return *iptr = arg, half(detail::binary, arg.data_ & 0x8000); - if(abs < 0x3C00) - return iptr->data_ = arg.data_ & 0x8000, arg; - unsigned int exp = abs >> 10, mask = (1 << (25 - exp)) - 1, m = arg.data_ & mask; - iptr->data_ = arg.data_ & ~mask; - if(!m) - return half(detail::binary, arg.data_ & 0x8000); - for(; m < 0x400; m <<= 1, --exp) - ; - return half(detail::binary, (arg.data_ & 0x8000) | (exp << 10) | (m & 0x3FF)); -} - -/// Extract exponent. -/// **See also:** Documentation for -/// [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). -/// \param arg number to query -/// \return floating-point exponent -/// \retval FP_ILOGB0 for zero -/// \retval FP_ILOGBNAN for NaN -/// \retval INT_MAX for infinity -/// \exception FE_INVALID for 0 or infinite values -inline int ilogb(half arg) -{ - int abs = arg.data_ & 0x7FFF, exp; - if(!abs || abs >= 0x7C00) - { - detail::raise(FE_INVALID); - return !abs ? FP_ILOGB0 : (abs == 0x7C00) ? 
INT_MAX : FP_ILOGBNAN; - } - for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) - ; - return exp; -} - -/// Extract exponent. -/// **See also:** Documentation for -/// [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). -/// \param arg number to query -/// \return floating-point exponent -/// \exception FE_INVALID for signaling NaN -/// \exception FE_DIVBYZERO for 0 -inline half logb(half arg) -{ - int abs = arg.data_ & 0x7FFF, exp; - if(!abs) - return half(detail::binary, detail::pole(0x8000)); - if(abs >= 0x7C00) - return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); - for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) - ; - unsigned int value = static_cast(exp < 0) << 15; - if(exp) - { - unsigned int m = std::abs(exp) << 6; - for(exp = 18; m < 0x400; m <<= 1, --exp) - ; - value |= (exp << 10) + m; - } - return half(detail::binary, value); -} - -/// Next representable value. -/// **See also:** Documentation for -/// [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). -/// \param from value to compute next representable value for -/// \param to direction towards which to compute next value -/// \return next representable value after \a from in direction towards \a to -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW for infinite result from finite argument -/// \exception FE_UNDERFLOW for subnormal result -inline half nextafter(half from, half to) -{ - int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; - if(fabs > 0x7C00 || tabs > 0x7C00) - return half(detail::binary, detail::signal(from.data_, to.data_)); - if(from.data_ == to.data_ || !(fabs | tabs)) - return to; - if(!fabs) - { - detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); - return half(detail::binary, (to.data_ & 0x8000) + 1); - } - unsigned int out = - from.data_ + - (((from.data_ >> 15) ^ - static_cast((from.data_ ^ (0x8000 | (0x8000 - (from.data_ >> 15)))) < - (to.data_ ^ (0x8000 | (0x8000 - (to.data_ >> 15)))))) - << 1) - - 1; - detail::raise(FE_OVERFLOW, fabs < 0x7C00 && (out & 0x7C00) == 0x7C00); - detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7C00) < 0x400); - return half(detail::binary, out); -} - -/// Next representable value. -/// **See also:** Documentation for -/// [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). -/// \param from value to compute next representable value for -/// \param to direction towards which to compute next value -/// \return next representable value after \a from in direction towards \a to -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW for infinite result from finite argument -/// \exception FE_UNDERFLOW for subnormal result -inline half nexttoward(half from, long double to) -{ - int fabs = from.data_ & 0x7FFF; - if(fabs > 0x7C00) - return half(detail::binary, detail::signal(from.data_)); - long double lfrom = static_cast(from); - if(detail::builtin_isnan(to) || lfrom == to) - return half(static_cast(to)); - if(!fabs) - { - detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); - return half(detail::binary, (static_cast(detail::builtin_signbit(to)) << 15) + 1); - } - unsigned int out = - from.data_ + (((from.data_ >> 15) ^ static_cast(lfrom < to)) << 1) - 1; - detail::raise(FE_OVERFLOW, (out & 0x7FFF) == 0x7C00); - detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7FFF) < 0x400); - return half(detail::binary, out); -} - -/// Take sign. 
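The exponent-manipulation helpers above in a short sketch (assumed setup as before):

    // assumes the include/using setup from the first sketch
    half up = nextafter(half(1.0f), half(2.0f)); // 1 + 2^-10, the next representable half above 1
    int  e  = ilogb(half(48.0f));                // 5, since 48 = 1.5 * 2^5
    half s  = scalbn(half(1.5f), 5);             // 48, exact scaling by a power of two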
-/// **See also:** Documentation for -/// [std::copysign](https://en.cppreference.com/w/cpp/numeric/math/copysign). -/// \param x value to change sign for -/// \param y value to take sign from -/// \return value equal to \a x in magnitude and to \a y in sign -inline HALF_CONSTEXPR half copysign(half x, half y) -{ - return half(detail::binary, x.data_ ^ ((x.data_ ^ y.data_) & 0x8000)); -} - -/// \} -/// \anchor classification -/// \name Floating point classification -/// \{ - -/// Classify floating-point value. -/// **See also:** Documentation for -/// [std::fpclassify](https://en.cppreference.com/w/cpp/numeric/math/fpclassify). -/// \param arg number to classify -/// \retval FP_ZERO for positive and negative zero -/// \retval FP_SUBNORMAL for subnormal numbers -/// \retval FP_INFINITY for positive and negative infinity -/// \retval FP_NAN for NaNs -/// \retval FP_NORMAL for all other (normal) values -inline HALF_CONSTEXPR int fpclassify(half arg) -{ - return !(arg.data_ & 0x7FFF) - ? FP_ZERO - : ((arg.data_ & 0x7FFF) < 0x400) - ? FP_SUBNORMAL - : ((arg.data_ & 0x7FFF) < 0x7C00) - ? FP_NORMAL - : ((arg.data_ & 0x7FFF) == 0x7C00) ? FP_INFINITE : FP_NAN; -} - -/// Check if finite number. -/// **See also:** Documentation for -/// [std::isfinite](https://en.cppreference.com/w/cpp/numeric/math/isfinite). -/// \param arg number to check -/// \retval true if neither infinity nor NaN -/// \retval false else -inline HALF_CONSTEXPR bool isfinite(half arg) { return (arg.data_ & 0x7C00) != 0x7C00; } - -/// Check for infinity. -/// **See also:** Documentation for -/// [std::isinf](https://en.cppreference.com/w/cpp/numeric/math/isinf). -/// \param arg number to check -/// \retval true for positive or negative infinity -/// \retval false else -inline HALF_CONSTEXPR bool isinf(half arg) { return (arg.data_ & 0x7FFF) == 0x7C00; } - -/// Check for NaN. -/// **See also:** Documentation for -/// [std::isnan](https://en.cppreference.com/w/cpp/numeric/math/isnan). -/// \param arg number to check -/// \retval true for NaNs -/// \retval false else -inline HALF_CONSTEXPR bool isnan(half arg) { return (arg.data_ & 0x7FFF) > 0x7C00; } - -/// Check if normal number. -/// **See also:** Documentation for -/// [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). -/// \param arg number to check -/// \retval true if normal number -/// \retval false if either subnormal, zero, infinity or NaN -inline HALF_CONSTEXPR bool isnormal(half arg) -{ - return ((arg.data_ & 0x7C00) != 0) & ((arg.data_ & 0x7C00) != 0x7C00); -} - -/// Check sign. -/// **See also:** Documentation for -/// [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). -/// \param arg number to check -/// \retval true for negative number -/// \retval false for positive number -inline HALF_CONSTEXPR bool signbit(half arg) { return (arg.data_ & 0x8000) != 0; } - -/// \} -/// \anchor compfunc -/// \name Comparison -/// \{ - -/// Quiet comparison for greater than. -/// **See also:** Documentation for -/// [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). -/// \param x first operand -/// \param y second operand -/// \retval true if \a x greater than \a y -/// \retval false else -inline HALF_CONSTEXPR bool isgreater(half x, half y) -{ - return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && - !isnan(x) && !isnan(y); -} - -/// Quiet comparison for greater equal. 
-/// **See also:** Documentation for -/// [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). -/// \param x first operand -/// \param y second operand -/// \retval true if \a x greater equal \a y -/// \retval false else -inline HALF_CONSTEXPR bool isgreaterequal(half x, half y) -{ - return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && - !isnan(x) && !isnan(y); -} - -/// Quiet comparison for less than. -/// **See also:** Documentation for -/// [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). -/// \param x first operand -/// \param y second operand -/// \retval true if \a x less than \a y -/// \retval false else -inline HALF_CONSTEXPR bool isless(half x, half y) -{ - return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && - !isnan(x) && !isnan(y); -} - -/// Quiet comparison for less equal. -/// **See also:** Documentation for -/// [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). -/// \param x first operand -/// \param y second operand -/// \retval true if \a x less equal \a y -/// \retval false else -inline HALF_CONSTEXPR bool islessequal(half x, half y) -{ - return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && - !isnan(x) && !isnan(y); -} - -/// Quiet comarison for less or greater. -/// **See also:** Documentation for -/// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). -/// \param x first operand -/// \param y second operand -/// \retval true if either less or greater -/// \retval false else -inline HALF_CONSTEXPR bool islessgreater(half x, half y) -{ - return x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF) && !isnan(x) && !isnan(y); -} - -/// Quiet check if unordered. -/// **See also:** Documentation for -/// [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). -/// \param x first operand -/// \param y second operand -/// \retval true if unordered (one or two NaN operands) -/// \retval false else -inline HALF_CONSTEXPR bool isunordered(half x, half y) { return isnan(x) || isnan(y); } - -/// \} -/// \anchor casting -/// \name Casting -/// \{ - -/// Cast to or from half-precision floating-point number. -/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values -/// are converted -/// directly using the default rounding mode, without any roundtrip over `float` that a -/// `static_cast` would otherwise do. -/// -/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any -/// of the two types -/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) -/// results in a compiler -/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. 
-/// \tparam T destination type (half or built-in arithmetic type) -/// \tparam U source type (half or built-in arithmetic type) -/// \param arg value to cast -/// \return \a arg converted to destination type -/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -template -T half_cast(U arg) -{ - return detail::half_caster::cast(arg); -} - -/// Cast to or from half-precision floating-point number. -/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values -/// are converted -/// directly using the specified rounding mode, without any roundtrip over `float` that a -/// `static_cast` would otherwise do. -/// -/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any -/// of the two types -/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) -/// results in a compiler -/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. -/// \tparam T destination type (half or built-in arithmetic type) -/// \tparam R rounding mode to use. -/// \tparam U source type (half or built-in arithmetic type) -/// \param arg value to cast -/// \return \a arg converted to destination type -/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -template -T half_cast(U arg) -{ - return detail::half_caster::cast(arg); -} -/// \} - -/// \} -/// \anchor errors -/// \name Error handling -/// \{ - -/// Clear exception flags. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). -/// \param excepts OR of exceptions to clear -/// \retval 0 all selected flags cleared successfully -inline int feclearexcept(int excepts) -{ - detail::errflags() &= ~excepts; - return 0; -} - -/// Test exception flags. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). -/// \param excepts OR of exceptions to test -/// \return OR of selected exceptions if raised -inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } - -/// Raise exception flags. -/// This raises the specified floating point exceptions and also invokes any additional automatic -/// exception handling as -/// configured with the [HALF_ERRHANDLIG_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). 
-/// \param excepts OR of exceptions to raise -/// \retval 0 all selected exceptions raised successfully -inline int feraiseexcept(int excepts) -{ - detail::errflags() |= excepts; - detail::raise(excepts); - return 0; -} - -/// Save exception flags. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). -/// \param flagp adress to store flag state at -/// \param excepts OR of flags to save -/// \retval 0 for success -inline int fegetexceptflag(int* flagp, int excepts) -{ - *flagp = detail::errflags() & excepts; - return 0; -} - -/// Restore exception flags. -/// This only copies the specified exception state (including unset flags) without incurring any -/// additional exception handling. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). -/// \param flagp adress to take flag state from -/// \param excepts OR of flags to restore -/// \retval 0 for success -inline int fesetexceptflag(const int* flagp, int excepts) -{ - detail::errflags() = (detail::errflags() | (*flagp & excepts)) & (*flagp | ~excepts); - return 0; -} - -/// Throw C++ exceptions based on set exception flags. -/// This function manually throws a corresponding C++ exception if one of the specified flags is -/// set, -/// no matter if automatic throwing (via [HALF_ERRHANDLING_THROW_...](\ref -/// HALF_ERRHANDLING_THROW_INVALID)) is enabled or not. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. 
-/// \param excepts OR of exceptions to test -/// \param msg error message to use for exception description -/// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set -/// \throw std::overflow_error if `FE_OVERFLOW` is selected and set -/// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set -/// \throw std::range_error if `FE_INEXACT` is selected and set -inline void fethrowexcept(int excepts, const char* msg = "") -{ - excepts &= detail::errflags(); - if(excepts & (FE_INVALID | FE_DIVBYZERO)) - throw std::domain_error(msg); - if(excepts & FE_OVERFLOW) - throw std::overflow_error(msg); - if(excepts & FE_UNDERFLOW) - throw std::underflow_error(msg); - if(excepts & FE_INEXACT) - throw std::range_error(msg); -} -/// \} -} // namespace half_float - -#undef HALF_UNUSED_NOERR -#undef HALF_CONSTEXPR -#undef HALF_CONSTEXPR_CONST -#undef HALF_CONSTEXPR_NOERR -#undef HALF_NOEXCEPT -#undef HALF_NOTHROW -#undef HALF_THREAD_LOCAL -#undef HALF_TWOS_COMPLEMENT_INT -#ifdef HALF_POP_WARNINGS -#pragma warning(pop) -#undef HALF_POP_WARNINGS -#endif - -#endif diff --git a/include/ck/config.hpp b/include/ck/ck.hpp similarity index 91% rename from include/ck/config.hpp rename to include/ck/ck.hpp index 1cdde85c42d99e433092fb690521a1bf014e3975..1e68e018a725e48de177e077600becc66f46efc0 100644 --- a/include/ck/config.hpp +++ b/include/ck/ck.hpp @@ -1,5 +1,7 @@ -#ifndef CK_CONFIG_AMD_HPP -#define CK_CONFIG_AMD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include "ck/options.hpp" @@ -11,6 +13,8 @@ #include "hip/hip_fp16.h" #endif +#define CK_TIME_KERNEL 1 + // constant address space for kernel parameter // https://llvm.org/docs/AMDGPUUsage.html#address-spaces #define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) @@ -109,7 +113,12 @@ #define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0 // experimental feature: buffer load/store/atomic-add/ OOB trick +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter for each usage +#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#endif #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 @@ -150,10 +159,6 @@ // tuning parameter #define CK_WORKAROUND_SWDEV_325164 1 -// workaround for verification failure ConvNd forward -// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135 -#define CK_WORKAROUND_GITHUB_135 1 - #ifndef CK_USE_X86_INLINE_ASM #define CK_USE_X86_INLINE_ASM 1 #endif @@ -168,6 +173,7 @@ enum struct InMemoryDataOperationEnum Add }; +// FIXME: use regular Sequence and remove this template struct InMemoryDataOperationEnumSequence { @@ -181,17 +187,8 @@ struct InMemoryDataOperationEnumSequence } }; -// TODO: no longer needed, remove this -enum struct ActivTypeEnum -{ - None, - LeakyRelu, - Sigmoid -}; - // index type using index_t = int32_t; using long_index_t = int64_t; } // namespace ck -#endif diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/device_utility/device_prop.hpp similarity index 90% rename from include/ck/host_utility/device_prop.hpp rename to include/ck/device_utility/device_prop.hpp index 74b20acecd3367317b20f8ed2f0f4aed49f5220f..e2cbdb733272d37aa5dbc7a746e86911d6b8644f 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/device_utility/device_prop.hpp @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include #include +#include namespace ck { diff --git a/include/ck/device_utility/hip_check_error.hpp b/include/ck/device_utility/hip_check_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d3dc8eaf1eb8b87207256ea4a521e23d5a49ca9c --- /dev/null +++ b/include/ck/device_utility/hip_check_error.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +inline void hip_check_error(hipError_t x) +{ + if(x != hipSuccess) + { + std::ostringstream ss; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ + << "in function: " << __func__; + throw std::runtime_error(ss.str()); + } +} diff --git a/include/ck/device_utility/kernel_launch.hpp b/include/ck/device_utility/kernel_launch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5879f9995e0bb16d31cc9ce3e7be9a3d82da0e32 --- /dev/null +++ b/include/ck/device_utility/kernel_launch.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/device_utility/hip_check_error.hpp" + +template +float launch_and_time_kernel(const StreamConfig& stream_config, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... 
args) +{ +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) + { + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + const int nrepeat = 10; + + printf("Warm up 1 time\n"); + + // warm up + kernel<<>>(args...); + + printf("Start running %d times...\n", nrepeat); + + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + kernel<<>>(args...); + } + + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + + float total_time = 0; + + hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + + return total_time / nrepeat; + } + else + { + kernel<<>>(args...); + + return 0; + } +#else + kernel<<>>(args...); + + return 0; +#endif +} diff --git a/include/ck/host_utility/xdnn_desc.hpp b/include/ck/device_utility/xdnn_desc.hpp similarity index 100% rename from include/ck/host_utility/xdnn_desc.hpp rename to include/ck/device_utility/xdnn_desc.hpp diff --git a/include/ck/options.hpp b/include/ck/options.hpp deleted file mode 100644 index 82c604f82badc8cf9a9333ae0b5e3e94eff53e3a..0000000000000000000000000000000000000000 --- a/include/ck/options.hpp +++ /dev/null @@ -1,3 +0,0 @@ -#pragma once - -#define CK_TIME_KERNEL 1 diff --git a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp index af682ecfa7ed30b57b92eb62c0358232570671d6..db8e48df6d4b18e9c8a60b8f3d4382d1c6e649f8 100644 --- a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp index 6693c0756b97102802c41c1bbd1b4a725b273354..5391b595b5c1e9b5b7395d1a676efcd87b9f2e6e 100644 --- a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp index e533ad918843c26d7e2f3b5a6e6f1665a11c4900..bb1dc239f4d7e6a5abade2dda4a17a98ab7d1a47 100644 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp index 949f044b7dda593b2829d7713da81245cdfa4363..ca530934e49dcccea388fd8501e5bf046f1c94af 100644 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp index 213e1d613515141d2fcd5103721af9590bf09fd2..e960f90c4bb1974ddfaaf9faf6a12d75d1db2f4d 100644 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp index f1e1826d1621ec18544b122308261fb17e6d3914..052bab423db04740c88721026a3cceb3fa77c292 100644 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp index 02e61c0ea3eff692a74a176c736b5c47dd3bad1e..c301a9e0c67b11c42fbe2a0a99c4f764049a5222 100644 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp index 7544289b21884de30877c6af85c7f6fe12bed261..41267536551ea298e111f6d93b8e6f3f8f9ed475 100644 --- a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp index 093a46256d7762ac245b9b7695164d33567cb9f2..381f9ac9d6f5a9b000eb44839c7198e4dc9f5d9b 100644 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp index 9aa27884da5b525062ba20faf3af60f1da322c7c..ebfaabb03ebdda41d086e4b80c5856156810992a 100644 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp index 16ae8b470dad06a56753a0a0e361e1db3be83338..6e576d69f5f4b5a9a60a05f84edb43fec990e91a 100644 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp index e81c87d046f2d6f3e34f6687935a5d2b47cbabaa..13e1bf251abfdc3716bdd17f331f82fd7bdfad41 100644 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp index ac90e8a6ffa8067d1fb36a24eb8fb92a7125734e..088d14b2ee472bc049385ad27e45d02df1e3b6c9 100644 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp index f5cb7f487704b4470b0235e34cfbec1c748a13d9..a6785d56df7e70fdb4882f0c431b886d8c87caac 100644 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index 342a2b1c83f8f778e8e55ea883f492ffc2eb2814..560d61047fce9eae299712a6554e3b21bb0626c5 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #ifndef CK_NOGPU diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index 681b7e9a8d3f4925138fd9660b5451fe0f23c1e5..e260efd56419bf0570d1ae30c794779b9aca3a7d 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP diff --git a/include/ck/tensor_description/cluster_descriptor.hpp b/include/ck/tensor_description/cluster_descriptor.hpp index d69bfb70c1e0365f22c7eddd7563bf5ec2c7e078..0c9ea2ff2a0d73b793008a954eaf7293b33ade08 100644 --- a/include/ck/tensor_description/cluster_descriptor.hpp +++ b/include/ck/tensor_description/cluster_descriptor.hpp @@ -1,8 +1,10 @@ -#ifndef CK_CLUSTER_DESCRIPTOR_HPP -#define CK_CLUSTER_DESCRIPTOR_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_adaptor.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { @@ -30,4 +32,3 @@ __host__ __device__ constexpr auto make_cluster_descriptor( } } // namespace ck -#endif diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index fa705cc3feea7f0cd62185f0522775a624bc8764..4e4d7593e9083f321f1cabe65faaee3c14259666 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,8 +1,10 @@ -#ifndef CK_MULTI_INDEX_TRANSFORM_HPP -#define CK_MULTI_INDEX_TRANSFORM_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "multi_index.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/multi_index.hpp" namespace ck { @@ -1950,4 +1952,3 @@ struct Modulo } }; } // namespace ck -#endif diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index bc360714b992b59020779bdf95fd0b2e7ac86154..044a90370095eb53b94b5c2fba81abdbeae82c00 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -1,8 +1,10 @@ -#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP -#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "common_header.hpp" -#include "multi_index_transform.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform.hpp" namespace ck { @@ -126,4 +128,3 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, return Modulo{modulus, up_length}; } } // namespace ck -#endif diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 8787abd6ba67c7668c6f79c53fcda376b13eb706..d42e0a6ff08f60eb1cf6e0f241f93e528bf3514b 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -1,9 +1,11 @@ -#ifndef CK_TENSOR_ADAPTOR_HPP -#define CK_TENSOR_ADAPTOR_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" namespace ck { @@ -136,7 +138,11 @@ struct TensorAdaptor using ElementSize = remove_cv_t; public: +#if 0 // workaround compiler complaint about constexpr __host__ __device__ constexpr TensorAdaptor() = default; +#else + __host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {} +#endif __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) : transforms_{transforms}, element_size_{InitializeElementSize(transforms)} @@ -474,4 +480,3 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. } } // namespace ck -#endif diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 49a884106fc560da7d9a4f6d0cff1febeb3e0994..76869939ddd30212a5afa12388b8a3db22850328 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -1,8 +1,10 @@ -#ifndef CK_TENSOR_DESCRIPTOR_HPP -#define CK_TENSOR_DESCRIPTOR_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "multi_index_transform.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform.hpp" namespace ck { @@ -111,7 +113,14 @@ struct TensorDescriptor using ElementSize = remove_cv_t; public: +#if 0 // workaround compiler complaint about constexpr __host__ __device__ constexpr TensorDescriptor() = default; +#else + __host__ __device__ constexpr TensorDescriptor() + : transforms_{}, element_size_{}, element_space_size_{} + { + } +#endif __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, ElementSpaceSize element_space_size) @@ -602,4 +611,3 @@ using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); } // namespace ck -#endif diff --git a/include/ck/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp index ddc0ede404dfccbc465cea6de1516e4d146795dc..461aae72cf7b1879c90e473adb3814e1c4875b52 100644 --- a/include/ck/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. 
All rights reserved. + #pragma once -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "multi_index_transform_helper.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" namespace ck { diff --git a/include/ck/utility/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp similarity index 93% rename from include/ck/utility/tensor_space_filling_curve.hpp rename to include/ck/tensor_description/tensor_space_filling_curve.hpp index b5f1a34d8370d7ddb29750178c6e05dd86f0aaeb..e9a990d857caa6e4d11c4473c617164ccc2d38a3 100644 --- a/include/ck/utility/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -1,12 +1,14 @@ -#ifndef TENSOR_SPACE_FILLING_CURVE_HPP -#define TENSOR_SPACE_FILLING_CURVE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "math.hpp" -#include "sequence.hpp" -#include "sequence_helper.hpp" -#include "tensor_adaptor.hpp" -#include "statically_indexed_array_multi_index.hpp" -#include "tuple_helper.hpp" +#pragma once + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/statically_indexed_array_multi_index.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { @@ -156,4 +158,3 @@ struct SpaceFillingCurve }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index f7fa867e1629ffdef5cb08dcdf5a2e1304401a48..8b1b7be11ef7b3cd19f11bd400b8a96146921bf7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -1,8 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "tensor_adaptor.hpp" -#include "threadwise_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_contraction_dl.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp index 2a8a4bc8b882b7ca148feb27fb9430d7fc7ef4b8..33120bd86ff01da47f7b80c593bc966785cb3711 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp index 78cfc1e0fbf1c2a661fe50dade2bb15aa8d9f85a..f45655721fe4de11b50b91e0dc5d22790ff73bea 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index a989cb5297a4ea5337fa934748e455ba36630fe5..9720db4a954c522204c02d47f30303ccfd14329f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1,9 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "xdlops_gemm.hpp" -#include "tensor_adaptor.hpp" -#include "thread_group.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { @@ -438,7 +441,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(n0, I0, I0, I0), b_thread_buf); }); - __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_sched_barrier(0); // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except // the first, as we can shorten non-MAC cluster a bit and there's no observable negative // impact. The desired effect is waves in a workgroup executing MAC in sync. 
This avoids @@ -448,7 +451,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) { asm volatile("s_barrier" ::); - __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_sched_barrier(0); } static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -480,9 +483,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && n0.value == NRepeat - 1) { - __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_sched_barrier(0); block_sync_lds(); - __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_sched_barrier(0); } // TODO: insert setprio in more precise manner since we @@ -493,16 +496,16 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 c_thread_buf.GetVectorTypeReference(Number{})); if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) { - __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_sched_barrier(0); } }); }); }); - __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_setprio(0); - __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_sched_barrier(0); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index e8ec1643640da2322f8411ddd550c5f80c278b8e..03e4d42d3a1f9eb1b180c95368b905e619e67110 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -1,11 +1,13 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v5r1.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp" namespace ck { @@ -152,4 +154,3 @@ struct BlockwiseTensorSliceTransfer_v5r1 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp index cc452b5e5cae505ffa18ddd8ad976a7d05fddd09..cce560367f3894407194b4128af29925e6175a53 100644 --- a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp +++ b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -1,35 +1,11 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP -#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP - -#include "reduction_common.hpp" -#include "reduction_functions_accumulate.hpp" - -#include "cluster_descriptor.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" namespace ck { @@ -45,7 +21,9 @@ template + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithNanCheck> struct PartitionedBlockwiseReduction { static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), @@ -62,8 +40,6 @@ struct PartitionedBlockwiseReduction static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - using Accumulation = detail::AccumulateWithNanCheck; - template __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) { @@ -113,13 +89,16 @@ struct PartitionedBlockwiseReduction // 3) in_out_value/in_out_index is the input data in vgpr from each thread // 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread // clang-format on -template +template < + typename AccDataType, + typename IndexDataType, + index_t BlockSize, + typename ThreadClusterLengths_M_K, + typename ThreadClusterArrangeOrder, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> struct PartitionedBlockwiseReductionWithIndex { static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), @@ -136,9 +115,6 @@ struct PartitionedBlockwiseReductionWithIndex static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - using Accumulation = - detail::AccumulateWithIndexAndNanCheck; - // This interface accumulates on both data values and indices template __device__ static void Reduce(BufferType& work_val_buffer, @@ -193,6 +169,4 @@ struct PartitionedBlockwiseReductionWithIndex }; }; -}; // end of namespace ck - -#endif +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp 
b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp index cbabbaf47dfcb085b7f4b9811a24ca66e119660c..0e5dfb355fb00565f8d9834a2e0df05bfd9db416 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -1,9 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v3r1.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp index 1f0ad3e35af0e38ea3189408731678cda7ce60fa..5c47a49b38b6a30eebc9190cc338d4dbc0bc8524 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -1,9 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v6r1.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp index 121ddf12ad93b8054b3c4bc06842ef88d4b9fc61..aa33fc083f15cfd8904c8fa086a866dbc7817e7c 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -1,9 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v6r2.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp index ca5db90f30710fc519733afae8c7c1c37471ffa8..eb5f589a4ada976c9be4a5001fb6fc288c7e9a43 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -1,9 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v6r3.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3bd7806389b59542b191e61f5369aebf7300fc6a --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalarPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does the following things to avoid the scratch memory issue +// 1. Pass tensor descriptors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags> +struct ThreadGroupTensorSliceTransfer_v7 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs); + } + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp index eae1bf9f8eea380720f651ea98751f41e59e2a82..6a226b0c53af923410ae4d2117b8e18e81e2d17b 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION #define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 60995e068ce9d43634de97705f73e311bc53311a..f4607ee6124e4af09a7fa3884776ee2b16a433df 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once namespace ck { diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index c9eaf64d667bfa2e6c00d6dd94542b7eea7bba25..ea60a4e6d90557fd1df10e6d45c8a4e75ec41d74 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CONVOLUTION_FORWARD_SPECIALIZATION #define CONVOLUTION_FORWARD_SPECIALIZATION @@ -15,7 +18,7 @@ enum struct ConvolutionForwardSpecialization OddC, }; -inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s) +inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) { switch(s) { diff --git a/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp index 6ca0790ce4e658dcf90f86fd3031930dcde5db28..c228045bdbce261ec2dea9f95d265b60a7016538 100644 --- a/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp @@ -1,13 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "common_header.hpp" -#include "gridwise_5ary_Elementwise_1d.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -29,7 +35,7 @@ template -struct Device5AryElementwise : public BaseOperator +struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor> { static constexpr auto I0 = Number<0>{}; @@ -262,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator return true; }; - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - const CDataType* p_c, - const DDataType* p_d, - const EDataType* p_e, - FDataType* p_f, + static auto MakeArgument(std::array p_inputs, + std::array p_outputs, std::vector lengths, std::vector a_strides, std::vector b_strides, @@ -277,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator std::vector f_strides, ElementwiseFunctor functor) { - return Argument{p_a, - p_b, - p_c, - p_d, - p_e, - p_f, + return Argument{static_cast(p_inputs[0]), + static_cast(p_inputs[1]), + static_cast(p_inputs[2]), + static_cast(p_inputs[3]), + static_cast(p_inputs[4]), + static_cast(p_outputs[0]), lengths, a_strides, b_strides, @@ -293,39 +295,57 @@ struct Device5AryElementwise : public BaseOperator functor}; } - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - const void* p_c, - const void* p_d, - const void* p_e, - void* p_f, - std::vector lengths, - std::vector a_strides, - std::vector b_strides, - std::vector c_strides, - std::vector d_strides, - 
std::vector e_strides, - std::vector f_strides, - ElementwiseFunctor functor) + std::unique_ptr + MakeArgumentPointer(std::array p_inputs, + std::array p_outputs, + std::vector lengths, + std::vector> input_strides, + std::vector> output_strides, + ElementwiseFunctor functor) override { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_c), - static_cast(p_d), - static_cast(p_e), - static_cast(p_f), + return std::make_unique(static_cast(p_inputs[0]), + static_cast(p_inputs[1]), + static_cast(p_inputs[2]), + static_cast(p_inputs[3]), + static_cast(p_inputs[4]), + static_cast(p_outputs[0]), lengths, - a_strides, - b_strides, - c_strides, - d_strides, - e_strides, - f_strides, + input_strides[0], + input_strides[1], + input_strides[2], + input_strides[3], + input_strides[4], + output_strides[0], functor); } static auto MakeInvoker() { return Invoker{}; } - std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "Device5aryElementwise" + << "<" + << "NDim = " << NDim + << "MPerThread = " << MPerThread + << "AScalarPerVector = " << AScalarPerVector + << "BScalarPerVector = " << BScalarPerVector + << "CScalarPerVector = " << CScalarPerVector + << "DScalarPerVector = " << DScalarPerVector + << "EScalarPerVector = " << EScalarPerVector + << "FScalarPerVector = " << FScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } }; // namespace device } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 1f6319d3f75e446da6ae2fc78b99a3a381df2e13..f41f65d76b5572efad1380f677b60e1e9f1dc3c6 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -1,8 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include -#include "stream_config.hpp" +#include "ck/stream_config.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +18,8 @@ struct BaseArgument BaseArgument& operator=(const BaseArgument&) = default; virtual ~BaseArgument() {} + + void* p_workspace_ = nullptr; }; struct BaseInvoker @@ -42,7 +47,11 @@ struct BaseOperator virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } - virtual void SetWorkSpacePointer(BaseArgument*, void*) const {} + virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const + { + assert(p_arg); + p_arg->p_workspace_ = p_workspace; + } virtual ~BaseOperator() {} }; diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e755913280f73e325daf1352324bab8d3df8a3a9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
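Aside (not part of the patch): the elementwise device ops earlier in this patch now take type-erased input/output pointer arrays plus one stride vector per tensor in MakeArgumentPointer, instead of one named parameter per tensor. A minimal packing sketch under assumed extents (buffers and sizes are hypothetical; index_t here is a local stand-in for ck::index_t, which is int32_t in CK):

#include <array>
#include <cstdint>
#include <vector>

using index_t = std::int32_t; // stand-in for ck::index_t

void pack_5ary_arguments(const float* a, const float* b, const float* c,
                         const float* d, const float* e, float* f)
{
    // Five inputs, one output, matching Device5AryElementwise.
    std::array<const void*, 5> p_inputs{a, b, c, d, e};
    std::array<void*, 1>       p_outputs{f};

    // One stride vector per tensor, in the same order as the pointers;
    // here a packed 2D [M, N] = [256, 128] layout is assumed for every tensor.
    std::vector<index_t> lengths{256, 128};
    std::vector<std::vector<index_t>> input_strides(5, {128, 1});
    std::vector<std::vector<index_t>> output_strides(1, {128, 1});

    // These would then be passed to MakeArgumentPointer(p_inputs, p_outputs, lengths,
    // input_strides, output_strides, functor) on a concrete device op instance.
    (void)p_inputs; (void)p_outputs; (void)lengths;
    (void)input_strides; (void)output_strides;
}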
+ +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB, + ck::index_t BatchStrideC, + ck::index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchedGemmPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp similarity index 52% rename from include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp rename to include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp index 9f5d16a1f9b940ff01e79ad66e13196a2fd69a06..90c8f79d865c02f3377997e8464f2beda58e3711 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp @@ -1,30 +1,38 @@ #pragma once #include +#include + #include "device_base.hpp" namespace ck { namespace tensor_operation { namespace device { +struct BatchedGemmCPermuteDesc +{ + ck::index_t G0_, G1_, M_, N_; + ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_; +}; + template -struct DeviceGemmBias : public BaseOperator +struct DeviceBatchedGemmCPermute : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, - const void* p_bias, void* p_c, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) = 0; + CElementwiseOperation c_element_op, + ck::index_t BatchCount) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; @@ -32,8 +40,8 @@ struct DeviceGemmBias : public BaseOperator template -using DeviceGemmBiasPtr = std::unique_ptr< - DeviceGemmBias>; +using DeviceBatchedGemmCPermutePtr = std::unique_ptr< + DeviceBatchedGemmCPermute>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fc65c811121cf335bd192d4ed83ef5815a1faffd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp @@ -0,0 +1,860 @@ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include 
"ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + ck::Tuple<>{}, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + c_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ck::StaticallyIndexedArray< + typename 
GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + 0>{}, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + 
transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto + MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N) + { + const auto c_grid_desc_mraw_nraw = [&]() { + return 
make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(stride_M, stride_N)); + }(); + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, + index_t G1, + index_t MRaw, + index_t NRaw, + index_t stride_G0, + index_t stride_G1, + index_t stride_M, + index_t stride_N) + { + const auto e_grid_desc_g0_g1_mraw_nraw = [&]() { + return make_naive_tensor_descriptor( + make_tuple(G0, G1, MRaw, NRaw), + make_tuple(stride_G0, stride_G1, stride_M, stride_N)); + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_pass_through_transform(MRaw), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, 
Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else + { + // not pad M or N + return e_grid_desc_g0_g1_mraw_nraw; + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1)); + using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, + index_t Batchstride_B, + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n) + : Batchstride_A_(Batchstride_A), + Batchstride_B_(Batchstride_B), + e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(Batchstride_A_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(Batchstride_B_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1); + index_t b0 = g_idx / G1; + index_t b1 = g_idx - b0 * G1; // g_idx % G1 + return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0)); + } + + private: + index_t Batchstride_A_; + index_t Batchstride_B_; + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; + }; + + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, // CShuffleDataType, + ck::Tuple<>, // DsDataType, + CDataType, // EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + NumPrefetch, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + 
BatchCount_(BatchCount), + a_grid_desc_k0_m_k1_{ + DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)}, + b_grid_desc_k0_n_k1_{ + DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, stride_B)}, + c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N( + batched_gemm_c_permute_desc.M_, + batched_gemm_c_permute_desc.N_, + batched_gemm_c_permute_desc.stride_M_, + batched_gemm_c_permute_desc.stride_N_)}, + e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N( + batched_gemm_c_permute_desc.G0_, + batched_gemm_c_permute_desc.G1_, + batched_gemm_c_permute_desc.M_, + batched_gemm_c_permute_desc.N_, + batched_gemm_c_permute_desc.stride_G0_, + batched_gemm_c_permute_desc.stride_G1_, + batched_gemm_c_permute_desc.stride_M_, + batched_gemm_c_permute_desc.stride_N_)}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + compute_ptr_offset_of_batch_{ + type_convert(a_grid_desc_k0_m_k1_.GetElementSpaceSize()), + type_convert(b_grid_desc_k0_n_k1_.GetElementSpaceSize()), + e_grid_desc_g0_g1_m_n_}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + index_t BatchCount_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + Block2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmCPermuteXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_c_permute_xdl< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.BatchCount_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + stride_A, + stride_B, + batched_gemm_c_permute_desc, + a_element_op, + b_element_op, + c_element_op, + BatchCount}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + stride_A, + stride_B, + batched_gemm_c_permute_desc, + a_element_op, + b_element_op, + c_element_op, + BatchCount); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmCPermuteXdl" + << "<" 
+ << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp index dc2a7a72ab3c3f4406aa029c5d2ed2b3a0a9094b..1486f0ac73d00c35cf54eb29760a4409d5906512 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -1,14 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + #include #include -#include "device.hpp" -#include "device_gemm_reduce.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" -#include "gemm_specialization.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -17,16 +23,16 @@ namespace device { template @@ -38,18 +44,18 @@ __global__ void const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - DPtrsGlobal p_ds_grid, + ReducePtrsGlobal p_reduces_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, - const DxsInElementwiseOperation dxs_in_element_op, - const DxsAccElementwiseOperation dxs_out_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) { @@ -65,10 +71,10 @@ __global__ void const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); - static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); - p_ds_grid(In) = p_ds_grid(In) + d_batch_offset; + p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; }); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -76,36 +82,36 @@ __global__ void GridwiseGemm::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, p_c_grid + c_batch_offset, - p_ds_grid, + 
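Aside (not part of the patch): a sketch of filling the BatchedGemmCPermuteDesc introduced above for an E tensor stored as [G0, M, G1, N] (for example batch, rows, heads, columns) while the GEMM still runs over [M, N] per (G0, G1) pair. The extents are hypothetical; the strides follow from that assumed layout.

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"

ck::tensor_operation::device::BatchedGemmCPermuteDesc make_permute_desc()
{
    const ck::index_t G0 = 4, G1 = 8, M = 256, N = 128;

    // Memory layout [G0, M, G1, N]: N is innermost, then G1, then M, then G0.
    return ck::tensor_operation::device::BatchedGemmCPermuteDesc{
        G0, G1, M, N,
        M * G1 * N, // stride_G0
        N,          // stride_G1
        G1 * N,     // stride_M
        1};         // stride_N
}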
p_reduces_grid, p_shared, a_element_op, b_element_op, c_element_op, - dxs_in_element_op, - dxs_out_element_op, + reduce_in_element_ops, + reduce_out_element_ops, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, - d_grid_desc_mblock_mperblock, + reduce_grid_desc_mblock_mperblock, block_2_ctile_map); #else ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = p_ds_grid; + ignore = p_reduces_grid; ignore = batch_count; ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; - ignore = dxs_in_element_op; - ignore = dxs_out_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; ignore = a_grid_desc_ak0_m_ak1; ignore = b_grid_desc_bk0_n_bk1; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = d_grid_desc_mblock_mperblock; + ignore = reduce_grid_desc_mblock_mperblock; ignore = compute_base_ptr_of_batch_; ignore = block_2_ctile_map; -#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__)) +#endif } // Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle @@ -120,14 +126,14 @@ template -struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce +struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> { using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; @@ -440,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), type_convert(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), type_convert(c_grid_desc_m_n_.GetElementSpaceSize()), - type_convert(d_grid_desc_m_.GetElementSpaceSize())}, + type_convert(reduce_grid_desc_m_.GetElementSpaceSize())}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}, - dxs_in_element_op_{dxs_in_element_op}, - dxs_out_element_op_{dxs_out_element_op} + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} { if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, @@ -621,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce; @@ -721,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce; @@ -764,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t Batch) { - return Argument{p_a, - p_b, - p_c, - p_dxs, - MRaw, - NRaw, - KRaw, + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return 
*(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, StrideA, StrideB, StrideC, a_element_op, b_element_op, c_element_op, - dxs_in_element_op, - dxs_out_element_op, - BatchCount}; + reduce_in_element_ops, + reduce_out_element_ops, + Batch}; } static auto MakeInvoker() { return Invoker{}; } // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - DPtrsGlobal p_dxs, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - DxsInElementwiseOperation dxs_in_element_op, - DxsAccElementwiseOperation dxs_out_element_op, - index_t BatchCount) override + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t Batch = 1) override { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), - p_dxs, - MRaw, - NRaw, - KRaw, + reduce_tuple, + M, + N, + K, StrideA, StrideB, StrideC, a_element_op, b_element_op, c_element_op, - dxs_in_element_op, - dxs_out_element_op, - BatchCount); + reduce_in_element_ops, + reduce_out_element_ops, + Batch); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp index d1ffa9df1476b4ff3609e7a33558e3270345f95c..ee94290a9d2440c7cf65188852130f3904ab9d31 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp @@ -1,16 +1,20 @@ -#ifndef DEVICE_BATCHED_GEMM_XDL_HPP -#define DEVICE_BATCHED_GEMM_XDL_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
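Aside (not part of the patch): the MakeArgument/MakeArgumentPointer overloads above receive the element-wise operators as type-erased void pointers and rebuild the typed tuple entry by entry with generate_tuple. The same idea expressed with plain std::tuple, where the operator types are hypothetical stand-ins:

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

struct PassThrough { };
struct UnarySquare { };

template <typename... Ops, std::size_t... Is>
std::tuple<Ops...> from_erased_impl(const std::array<void*, sizeof...(Ops)>& erased,
                                    std::index_sequence<Is...>)
{
    // Cast each void* back to the operator type known from the parameter pack,
    // mirroring the per-index lambdas passed to generate_tuple in the patch above.
    return std::tuple<Ops...>{*static_cast<Ops*>(erased[Is])...};
}

template <typename... Ops>
std::tuple<Ops...> from_erased(const std::array<void*, sizeof...(Ops)>& erased)
{
    return from_erased_impl<Ops...>(erased, std::index_sequence_for<Ops...>{});
}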
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -109,7 +113,7 @@ __global__ void ignore = c_element_op; ignore = compute_ptr_offset_of_batch; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif } template -struct DeviceBatchedGemmXdl - : public DeviceGemm +struct DeviceBatchedGemmXdl : public DeviceBatchedGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -330,26 +341,26 @@ struct DeviceBatchedGemmXdl index_t StrideA, index_t StrideB, index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t BatchCount) + CElementwiseOperation c_element_op) : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, - BatchCount_(BatchCount), + Batch_(Batch), a_grid_desc_k0_m_k1_{ DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, b_grid_desc_k0_n_k1_{ DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - compute_ptr_offset_of_batch_{ - type_convert(a_grid_desc_k0_m_k1_.GetElementSpaceSize()), - type_convert(b_grid_desc_k0_n_k1_.GetElementSpaceSize()), - type_convert(c_grid_desc_m_n_.GetElementSpaceSize())}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC}, block_2_ctile_map_{ GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, M01_{M01}, @@ -372,7 +383,7 @@ struct DeviceBatchedGemmXdl const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - index_t BatchCount_; + index_t Batch_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; @@ -416,7 +427,7 @@ struct DeviceBatchedGemmXdl } const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -447,7 +458,7 @@ struct DeviceBatchedGemmXdl arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, - arg.BatchCount_, + arg.Batch_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, @@ -481,7 +492,7 @@ struct DeviceBatchedGemmXdl arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, - arg.BatchCount_, + arg.Batch_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, @@ -532,10 +543,13 @@ struct DeviceBatchedGemmXdl index_t StrideA, 
index_t StrideB, index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t BatchCount) + CElementwiseOperation c_element_op) { return Argument{p_a, p_b, @@ -546,12 +560,15 @@ struct DeviceBatchedGemmXdl StrideA, StrideB, StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, 1, 1, a_element_op, b_element_op, - c_element_op, - BatchCount}; + c_element_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -566,10 +583,13 @@ struct DeviceBatchedGemmXdl index_t StrideA, index_t StrideB, index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t BatchCount) override + CElementwiseOperation c_element_op) override { return std::make_unique(static_cast(p_a), static_cast(p_b), @@ -580,12 +600,15 @@ struct DeviceBatchedGemmXdl StrideA, StrideB, StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, 1, 1, a_element_op, b_element_op, - c_element_op, - BatchCount); + c_element_op); } // polymorphic @@ -616,4 +639,3 @@ struct DeviceBatchedGemmXdl } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp index 34b3a59c7473cbebb26c3aafc056ce84c90d1591..99be946e92e7fcb7f64dc19432308a3dc8cd6714 100644 --- a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -1,10 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
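Aside (not part of the patch): with the DeviceBatchedGemmXdl interface change above, the batch strides are supplied by the caller instead of being derived from the descriptors' element space sizes, so non-packed or broadcast batches become expressible. A sketch of the packed row-major case under hypothetical extents:

#include <cstdint>

struct BatchedGemmShape
{
    std::int32_t M, N, K, Batch;
};

struct BatchStrides
{
    std::int32_t A, B, C;
};

constexpr BatchStrides packed_row_major_batch_strides(BatchedGemmShape s)
{
    // A is [Batch, M, K], B is [Batch, K, N], C is [Batch, M, N], each batch stored
    // back to back; a strided or broadcast batch would simply pass different values.
    return BatchStrides{s.M * s.K, s.K * s.N, s.M * s.N};
}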
+ #pragma once + #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "gridwise_binary_elementwise_1d.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" namespace ck { namespace tensor_operation { @@ -20,7 +26,7 @@ template -struct DeviceBinaryElementwise : public BaseOperator +struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor> { static constexpr auto I0 = Number<0>{}; @@ -193,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator return true; }; - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - std::vector lengths, - std::vector a_strides, - std::vector b_strides, - std::vector c_strides, - ElementwiseFunctor functor) + virtual std::unique_ptr + MakeArgumentPointer(std::array p_inputs, + std::array p_outputs, + std::vector lengths, + std::vector> input_strides, + std::vector> output_strides, + ElementwiseFunctor functor) override { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_c), + return std::make_unique(static_cast(p_inputs[0]), + static_cast(p_inputs[1]), + static_cast(p_outputs[0]), lengths, - a_strides, - b_strides, - c_strides, + input_strides[0], + input_strides[1], + output_strides[0], functor); } - std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + // polymorphic std::string GetTypeString() const override { auto str = std::stringstream(); @@ -221,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator // clang-format off str << "DeviceBinaryElementwise" << "<" + << "NDim = " << NDim << "MPerThread = " << MPerThread + << "AScalarPerVector = " << AScalarPerVector + << "BScalarPerVector = " << BScalarPerVector + << "CScalarPerVector = " << CScalarPerVector << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index ad4fde750fc5472e830dedaa8c37c737a9356ea0..aedae53800b03cb78e6af5b15c8c6dfcd9792eea 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,28 +1,6 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2022 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "device_base.hpp" diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp index 4e1aada6dae094520e3ae14a58f8e449da5d6321..ac6b23479c50458362b05034f4babbe3c326d015 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -1,42 +1,23 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2022 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
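Aside (not part of the patch): the Add/Subtract element-wise kernels used by DeviceCGemm_4Gemm_Xdl_CShuffle above stitch four real GEMMs into one complex GEMM. For scalars the same identity reads (a_re + i*a_im)(b_re + i*b_im) = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re); the real part is the subtract-kernel path ("c_real = aux - aux_2" in the code), the imaginary part the add-kernel path. A scalar sketch:

struct ComplexScalar
{
    float re, im;
};

constexpr ComplexScalar cmul(ComplexScalar a, ComplexScalar b)
{
    return ComplexScalar{a.re * b.re - a.im * b.im,  // subtract path: c_real
                         a.re * b.im + a.im * b.re}; // add path: c_imag
}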
+ #pragma once + #include #include -#include "device.hpp" -#include "device_gemm.hpp" -#include "device_cgemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdl_cshuffle_v1.hpp" -#include "binary_element_wise_operation.hpp" -#include "gridwise_binary_elementwise_1d.hpp" -#include "tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_cgemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -557,11 +538,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle float ave_time = 0; - using Add = - ck::tensor_operation::binary_element_wise::Add; - using Substract = ck::tensor_operation::binary_element_wise:: - Substract; - using GridwiseBinAdd = GridwiseBinaryElementwise_1D; - using GridwiseBinSubstract = GridwiseBinaryElementwise_1D; - const auto add_kernel = kernel_binary_elementwise_1d; + const auto add_kernel = kernel_binary_elementwise_1d; - const auto substract_kernel = kernel_binary_elementwise_1d; + const auto subtract_kernel = kernel_binary_elementwise_1d; if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { @@ -653,7 +632,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle // c_real = aux - aux_2 ave_time += launch_and_time_kernel(stream_config, - substract_kernel, + subtract_kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -663,7 +642,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_m_, arg.c_grid_desc_m_, arg.c_grid_desc_m_, - Substract{}); + Subtract{}); ave_time += launch_and_time_kernel(stream_config, @@ -764,7 +743,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle // c_real = aux - aux_2 ave_time += launch_and_time_kernel(stream_config, - substract_kernel, + subtract_kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -774,7 +753,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_m_, arg.c_grid_desc_m_, arg.c_grid_desc_m_, - Substract{}); + Subtract{}); ave_time += launch_and_time_kernel(stream_config, diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fa0f07d3797c824a249719e084e27a155617782d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] 
+// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::vector a_ms_ks_lengths, + std::vector a_ms_ks_strides, + std::vector b_ns_ks_lengths, + std::vector b_ns_ks_strides, + std::array, NumDTensor> ds_ms_ns_lengths, + std::array, NumDTensor> ds_ms_ns_strides, + std::vector e_ms_ns_lengths, + std::vector e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b130290fbe358af88549dcdb38e4e72f6aa8cec3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,981 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; 
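Note (illustrative aside, not part of the patch): for NumDimM = NumDimN = NumDimK = 1 and a single D tensor, the contraction described by the comment block above reduces to a GEMM with a fused element-wise epilogue. A minimal host-side reference, with identity a_op/b_op and an "add D0" cde_op chosen purely for illustration:

// Naive reference for E = cde_op(a_op(A) * b_op(B), D0), assuming one M, N and K
// dimension, identity a_op/b_op and cde_op(c, d) = c + d (all hypothetical choices).
#include <cstddef>
#include <vector>

void reference_contraction(const std::vector<float>& a,  // A[M][K], K contiguous
                           const std::vector<float>& b,  // B[N][K], K contiguous
                           const std::vector<float>& d0, // D0[M][N]
                           std::vector<float>& e,        // E[M][N]
                           std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float c = 0.f; // C = a_op(A) * b_op(B)
            for(std::size_t k = 0; k < K; ++k)
                c += a[m * K + k] * b[n * K + k];
            e[m * N + n] = c + d0[m * N + n]; // E = cde_op(C, D0)
        }
}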
+ ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceContractionMultipleD_Xdl_CShuffle + : public DeviceContractionMultipleD +{ + using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector& a_ms_ks_lengths_vec, + const std::vector& a_ms_ks_strides_vec) + { + assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK && + a_ms_ks_strides_vec.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] 
+ const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0); + const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] 
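Note (illustrative aside, not part of the patch): MakeAGridDescriptor_AK0_M_AK1 above first merges the M- and K-dimension groups into MRaw and KRaw, right-pads them to multiples of MPerBlock and KPerBlock when the chosen GemmSpecialization requires it, and finally unmerges K into AK0 x AK1. With hypothetical sizes the arithmetic is:

// Sketch of the pad/split arithmetic (tile sizes below are made up; in the real
// code they are template parameters of the device op).
#include <cassert>

int main()
{
    const int MRaw = 100, KRaw = 70;          // merged raw lengths M0*M1*..., K0*K1*...
    const int MPerBlock = 64, KPerBlock = 32; // block tile sizes
    const int AK1 = 8;                        // inner K vector length

    const int M = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock; // 128
    const int K = (KRaw + KPerBlock - 1) / KPerBlock * KPerBlock; // 96
    const int MPad = M - MRaw;                                    // 28
    const int KPad = K - KRaw;                                    // 26

    assert(K % AK1 == 0);
    const int AK0 = K / AK1; // 12, i.e. A is viewed as [AK0, M, AK1] = [12, 128, 8]

    (void)MPad;
    (void)KPad;
    (void)AK0;
    return 0;
}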
+ static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector& b_ns_ks_lengths_vec, + const std::vector& b_ns_ks_strides_vec) + { + assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK && + b_ns_ks_strides_vec.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0); + const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + 
transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_vec, + const std::vector& e_ms_ns_strides_vec) + { + assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN && + e_ms_ns_strides_vec.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] 
+ const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto MRaw = e_grid_desc_mraw_nraw.GetLength(I0); + const auto NRaw = e_grid_desc_mraw_nraw.GetLength(I1); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(e_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + e_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + e_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return e_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = + decltype(MakeAGridDescriptor_AK0_M_AK1(std::vector{}, std::vector{})); + using BGridDesc_BK0_N_BK1 = + decltype(MakeBGridDescriptor_BK0_N_BK1(std::vector{}, std::vector{})); + using EGridDesc_M_N = + decltype(MakeEGridDescriptor_M_N(std::vector{}, std::vector{})); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + std::vector a_ms_ns_lengths, + std::vector a_ms_ks_strides, + std::vector b_ns_ks_lengths, + 
std::vector b_ns_ks_strides, + std::array, NumDTensor> ds_ms_ns_lengths, + std::array, NumDTensor> ds_ms_ns_strides, + std::vector e_ms_ns_lengths, + std::vector e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_ms_ns_lengths, a_ms_ks_strides)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_ns_ks_lengths, b_ns_ks_strides)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + ds_nz_stride_{}, + e_nz_stride_{} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + const auto d_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + } + + // for sanity check of vector memory access + a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; + a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; + + b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; + b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const 
StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + 
} + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!(arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::vector a_ms_ns_lengths, + std::vector a_ms_ks_strides, + std::vector b_ns_ks_lengths, + std::vector b_ns_ks_strides, + std::array, NumDTensor> ds_ms_ns_lengths, + std::array, NumDTensor> ds_ms_ns_strides, + std::vector e_ms_ns_lengths, + std::vector e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_ms_ns_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::vector a_ms_ns_lengths, + std::vector a_ms_ks_strides, + std::vector b_ns_ks_lengths, + std::vector b_ns_ks_strides, + std::array, NumDTensor> ds_ms_ns_lengths, + std::array, NumDTensor> ds_ms_ns_strides, + std::vector e_ms_ns_lengths, + std::vector e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_ms_ns_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // 
clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 8404f4c266e3ee05455d9393ad6ff0c968a174c5..31b2ca05e66b2a283601c3b50c286cdd14132aab 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,17 +1,21 @@ -#ifndef DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_backward_weight.hpp" -#include "convolution_forward_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_bwd_weight.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -773,4 +777,3 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 83953e59bd90e5137daaf763f76096f1769a198a..37ef8db332d96b597eeca1258070feab3b5d2733 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -1,17 +1,20 @@ -#ifndef DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
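Note (illustrative aside, not part of the patch): the vector-access checks added to IsSupportedArgument of the new contraction op above all follow one rule: the dimension being vectorized must be contiguous (stride 1) and its length must be divisible by the configured ScalarPerVector. The intended call sequence for the op itself is MakeArgumentPointer, then IsSupportedArgument, then MakeInvokerPointer and Invoker::Run. A self-contained sketch of the access rule:

// Minimal stand-alone version of the vector-access rule (numbers are hypothetical).
#include <cstdio>

bool can_vector_access(int stride_of_vector_dim, int length_of_vector_dim, int scalar_per_vector)
{
    return stride_of_vector_dim == 1 && length_of_vector_dim % scalar_per_vector == 0;
}

int main()
{
    std::printf("%d\n", can_vector_access(1, 128, 8)); // 1: contiguous and divisible
    std::printf("%d\n", can_vector_access(1, 130, 8)); // 0: 130 is not a multiple of 8
    std::printf("%d\n", can_vector_access(4, 128, 8)); // 0: stride 4 is not contiguous
    return 0;
}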
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_bwd_data.hpp" -#include "convolution_backward_data_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -821,4 +824,3 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index cc1c2cb2ca78d6a7444bb380efc863d5d7d56557..5b880b1fd64ec404cc24e9406dd94052c547cd14 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -1,17 +1,20 @@ -#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_fwd_bias_activation_add.hpp" -#include "convolution_forward_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r3.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -963,4 +966,3 @@ struct } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index a397b5e2b13d397e787215a9fce98f6a5156f6cb..bab9898785f69655dd75fef3e0029c2192923710 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -1,15 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_fwd_bias_activation.hpp" -#include "convolution_forward_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r2.hpp" +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 707413dfd3fdf56e81b4f6ae0c8d83f47ab53770..0e7d9cd4a807e02fad6ca97a48b7bbaefa0fe7f5 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,17 +1,20 @@ -#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_fwd.hpp" -#include "convolution_forward_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r1.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -868,7 +871,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W << MPerBlock << ", " << NPerBlock << ", " << K0PerBlock << ", " - << getConvFwdSpecializationStr(ConvForwardSpecialization) + << getConvForwardSpecializationString(ConvForwardSpecialization) << ">"; // clang-format on @@ -879,4 +882,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index ece18459a0c24476a8ed441ba23d2da65fd1d586..84166d6f5f2d2cc0789ecf87a12e2593062a8f66 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,17 +1,20 @@ -#ifndef DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_fwd.hpp" -#include "convolution_forward_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -708,15 +711,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << MPerBlock << ", " << NPerBlock << ", " << K0PerBlock << ", " - << getConvFwdSpecializationStr(ConvForwardSpecialization) + << getConvForwardSpecializationString(ConvForwardSpecialization) << ">"; // clang-format on return str.str(); } -}; // namespace device +}; } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp index b1eea0b33f37a095cbc9f44b3c572195e4f36c0d..f69d8f18ae036b8e7cd060a80a0cca12bfc2a050 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef DEVICE_CONV3D_FWD_NAIVE_HPP #define DEVICE_CONV3D_FWD_NAIVE_HPP diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 256d0f81e96c899777806121962be9c89df7c059..b48cfac0d8bdd4dc715eab4c13d0874e31f9151c 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef DEVICE_CONV3D_FWD_XDL_HPP #define DEVICE_CONV3D_FWD_XDL_HPP diff --git a/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp b/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp index 549cfb26f3d095b14b357136ab7a6618b0b7820e..f1712025308fcc57ebba0183ecb6e3cf95596784 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp @@ -1,8 +1,12 @@ -#ifndef DEVICE_CONV_WRW_HPP -#define DEVICE_CONV_WRW_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once + +#include #include -#include "device_base.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { namespace tensor_operation { @@ -44,4 +48,3 @@ using DeviceConvBwdWeightPtr = std::unique_ptr< } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp index 1d08af1a05edde477e75e43c37ed8cec0dd13cd7..83c19703b8bb430031ab57eb7aae5cd04a8c9e09 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp @@ -1,9 +1,13 @@ -#ifndef DEVICE_CONV_BWD_DATA_HPP -#define DEVICE_CONV_BWD_DATA_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include #include -#include "device_base.hpp" -#include "element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { namespace tensor_operation { @@ -44,4 +48,3 @@ using DeviceConvBwdDataPtr = std::unique_ptr< } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp index d53e56f18ba506ca1a67aa308ca4eec9311e89b3..5a3fb60d3bab9267148fa30b77ac5d967fd02ede 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -1,8 +1,12 @@ -#ifndef DEVICE_CONV_FWD_HPP -#define DEVICE_CONV_FWD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include -#include "device_base.hpp" +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { namespace tensor_operation { @@ -43,4 +47,3 @@ using DeviceConvFwdPtr = std::unique_ptr< } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp index 77d4b7fb95a7c414df3b81a024df10332c0c42cd..5a627deeb2221f5e271532d45bc8a544c8cceeba 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp @@ -1,8 +1,13 @@ -#ifndef DEVICE_CONV_FWD_BIAS_ACTIVATION_HPP -#define DEVICE_CONV_FWD_BIAS_ACTIVATION_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once + +#include #include -#include "device_base.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { namespace tensor_operation { @@ -46,4 +51,3 @@ using DeviceConvFwdBiasActivationPtr = } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp index 2f8e780b78d9b6feaadee9718e4bc1f39b2f96e1..cc139303c929ae999d182ad09e9a4ef33f5209e1 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp @@ -1,8 +1,12 @@ -#ifndef DEVICE_CONV_FWD_BIAS_ACTIVATION_ADD_HPP -#define DEVICE_CONV_FWD_BIAS_ACTIVATION_ADD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include #include -#include "device_base.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { namespace tensor_operation { @@ -47,4 +51,3 @@ using DeviceConvFwdBiasActivationAddPtr = } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index dde9e0f87390f97ed924ec033c318b6ba580d2cc..32d91269b4300d5266ca7a0e3b087035287affbe 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,16 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_backward_weight.hpp" -#include "convolution_backward_weight_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_bwd_weight.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -432,7 +437,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using namespace ck; const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[2]; + const index_t Hi = input_spatial_lengths[1]; const index_t Wi = input_spatial_lengths[2]; const index_t Do = output_spatial_lengths[0]; @@ -628,6 +633,57 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ 1); } + // type convert descs + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * 4; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + template + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] 
+ if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using TypeConvertFp32ToBf16Functor = + ck::tensor_operation::element_wise::UnaryTypeConvert; + using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1)); + using GridwiseUEltwise = GridwiseUnaryElementwise_1D; + using ABCGridDescs = decltype(GetABCGridDesc()); using AGridDesc_K0_M_K1 = remove_cvref_t; @@ -733,6 +789,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ true, true>; + using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + AccDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); @@ -881,18 +986,67 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + float ave_time = 0; + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - float ave_time = 0; + const auto run_conv = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + float elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); - const auto Run = [&](const auto& kernel) { hipGetErrorString(hipMemset( arg.p_c_grid_, 0, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * sizeof(CDataType))); - ave_time = + launch_and_time_kernel(StreamConfig{nullptr, false}, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + 
arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + + return elapsed_time; + }; + + // run kernel for bf16 with splitk + const auto run_bf16_splitk = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_workspace_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(AccDataType))); + + float elapsed_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size), @@ -900,7 +1054,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ 0, arg.p_a_grid_, arg.p_b_grid_, - arg.p_c_grid_, + static_cast(arg.p_workspace_), arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, @@ -908,49 +1062,77 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ arg.b_element_op_, arg.c_element_op_, arg.block_2_ctile_map_); + + hipGetErrorString(hipMemset( + arg.p_workspace_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(AccDataType))); + + launch_and_time_kernel(StreamConfig{nullptr, false}, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + static_cast(arg.p_workspace_), + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + + return elapsed_time; }; - if constexpr(std::is_same::value) + // kernel for type conversion + std::vector filter_dims{static_cast(arg.Conv_K_), + static_cast(arg.Conv_C_)}; + + filter_dims.insert(std::end(filter_dims), + std::begin(arg.filter_spatial_lengths_), + std::end(arg.filter_spatial_lengths_)); + + int tensor_size = + std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies{}); + + const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size); + GridDesc_M0 a_grid_desc_m0_ = + MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256); + GridDesc_M0 b_grid_desc_m0_ = + MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256); + + if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_)) { - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - OutElementwiseOperation, - InElementwiseOperation, - WeiElementwiseOperation, - remove_reference_t, - true>; - - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - OutElementwiseOperation, - InElementwiseOperation, - WeiElementwiseOperation, - remove_reference_t, - false>; - - Run(kernel); - } + throw std::runtime_error("wrong! 
GridwiseUnaryElementwise_1D has invalid setting"); } - else + + // run kernel for type conversion + void* p_c_grid_tmp_ = static_cast(arg.p_c_grid_); + InDataType* p_c_grid_tmp_bf16_ = static_cast(p_c_grid_tmp_); + const auto run_type_convert = [&](const auto& kernel) { + float elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(type_convert_grid_size), + dim3(256), + 0, + static_cast(arg.p_workspace_), + p_c_grid_tmp_bf16_, + a_grid_desc_m0_, + b_grid_desc_m0_, + TypeConvertFp32ToBf16Functor{}); + return elapsed_time; + }; + + if constexpr(std::is_same::value) { - if(has_main_k0_block_loop) - { + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + if(kbatch == 1) { const auto kernel = kernel_gemm_xdlops_bwd_weight< @@ -965,16 +1147,23 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ InElementwiseOperation, WeiElementwiseOperation, remove_reference_t, - true>; + has_main_loop>; - Run(kernel); + return run_conv(kernel); } else { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemmAtomicAdd, + const auto kernel_type_convert = + kernel_unary_elementwise_1d; + + const auto kernel_conv = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAddFloatBf16Splitk, ADataType, // TODO: distiguish A/B datatype - CDataType, + AccDataType, remove_reference_t, remove_reference_t, remove_reference_t< @@ -983,13 +1172,28 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ InElementwiseOperation, WeiElementwiseOperation, remove_reference_t, - true>; + has_main_loop>; - Run(kernel); + float elapsed_time = 0; + elapsed_time += run_bf16_splitk(kernel_conv); + elapsed_time += run_type_convert(kernel_type_convert); + return elapsed_time; } + }; + if(has_main_k0_block_loop) + { + ave_time = launch_kernel(integral_constant{}); } else { + ave_time = launch_kernel(integral_constant{}); + } + } + else + { + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + if(kbatch == 1) { const auto kernel = kernel_gemm_xdlops_bwd_weight< @@ -1004,9 +1208,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ InElementwiseOperation, WeiElementwiseOperation, remove_reference_t, - false>; + has_main_loop>; - Run(kernel); + return run_conv(kernel); } else { @@ -1022,10 +1226,18 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ InElementwiseOperation, WeiElementwiseOperation, remove_reference_t, - false>; + has_main_loop>; - Run(kernel); + return run_conv(kernel); } + }; + if(has_main_k0_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); } } @@ -1047,6 +1259,20 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ static bool IsSupportedArgument(const Argument& arg) { + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NumDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + // vector load A/B matrix from global memory if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && @@ -1171,6 
+1397,12 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ << NPerBlock << ", " << K0PerBlock << ">"; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0){ + + str << " Filter1x1Stride1Pad0"; + } + // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp index 0517db44154c09603c4dc69e5efefef0c97b7cfc..a5970c8f13cbd2fe2ace3e94bc5f212fa28de77a 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1,17 +1,20 @@ -#ifndef DEVICE_CONVND_BWD_DATA_XDL_NDHWC_KZYXC_NDHWK_HPP -#define DEVICE_CONVND_BWD_DATA_XDL_NDHWC_KZYXC_NDHWK_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_bwd_data.hpp" -#include "convolution_backward_data_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -1546,4 +1549,3 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index c1ab44a28b39ac82f403762cd8bd920a954527df..78f0f02898451c26b13cf495a246aeffa7eb2703 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once #include @@ -6,16 +9,15 @@ #include #include -#include "device.hpp" -#include "device_prop.hpp" -#include "device_base.hpp" -#include "device_conv_fwd.hpp" -#include "convolution_forward_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -1031,7 +1033,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << MPerBlock << ", " << NPerBlock << ", " << K0PerBlock << ", " - << getConvFwdSpecializationStr(ConvForwardSpecialization) + << getConvForwardSpecializationString(ConvForwardSpecialization) << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f0946eb846adf71c149d815282b9731988b6443d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwise : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(std::array p_inputs, + std::array p_outputs, + std::vector lengths, + std::vector> input_strides, + std::vector> output_strides, + ElementwiseFunctor functor) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceElementwisePtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm.hpp index 4576aaa7e032b2bc6ceab021018f94f005251d32..04b6e0c13e427e74b04a464bba12d94c7d096062 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -1,8 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once + #include #include -#include "device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { namespace tensor_operation { @@ -14,33 +18,52 @@ struct GemmShape ck::index_t StrideA, StrideB, StrideC; }; -template struct DeviceGemm : public BaseOperator { - virtual std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - ck::index_t KBatch = 1) = 0; + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; -template -using DeviceGemmPtr = std::unique_ptr< - DeviceGemm>; +using DeviceGemmPtr = std::unique_ptr>; template -#include "device_base.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DeviceGemmBiasActivationAdd : public BaseOperator -{ - virtual std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - const void* p_c0, - const void* p_c1, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC, - ck::index_t StrideC1, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - ck::index_t KBatch = 1) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -template -using DeviceGemmBiasActivationAddPtr = - std::unique_ptr>; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1aa3885523c27df0b52103cbcb2540c5241222a9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -0,0 +1,875 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. 
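Editor's note: the refactored DeviceGemm interface above drops the trailing KBatch argument from MakeArgumentPointer. The fragment below is a minimal, hypothetical usage sketch of that pointer-based API (not part of this diff); the instance type, device pointers, include paths and PassThrough functors are placeholders that may need adjusting to the actual headers.

// Hypothetical sketch: drive any concrete DeviceGemm instance through the
// refactored pointer-based interface; KBatch is no longer passed here.
#include <memory>
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" // assumed path

namespace example {

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

template <typename DeviceGemmInstance>
float run_gemm(DeviceGemmInstance& gemm,
               const void* p_a, const void* p_b, void* p_c,
               ck::index_t M, ck::index_t N, ck::index_t K,
               ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    auto argument = gemm.MakeArgumentPointer(p_a, p_b, p_c,
                                             M, N, K,
                                             StrideA, StrideB, StrideC,
                                             PassThrough{}, PassThrough{}, PassThrough{});
    auto invoker = gemm.MakeInvokerPointer();

    if(!gemm.IsSupportedArgument(argument.get()))
        return -1.f; // this instance cannot handle the given problem size/layout

    // time_kernel = true makes Run() return the averaged kernel time in ms
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}

} // namespace example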
+template +struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()> +{ + using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == 
GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0)); + using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + BiasDataType, + D0DataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 
BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BiasDataType* p_bias_grid, + const D0DataType* p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ElementwiseOperation d0_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_bias_grid_{p_bias_grid}, + p_d0_grid_{p_d0_grid}, + p_reduces_grid_{p_reduces_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)}, + c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + d0_element_op_{d0_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c1_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const BiasDataType* p_bias_grid_; + const D0DataType* p_d0_grid_; + ReducePtrsGlobal p_reduces_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename 
GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + D0ElementwiseOperation d0_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + BiasDataType, + D0DataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_bias_grid_, + arg.p_d0_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + BiasDataType, + D0DataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename 
GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_bias_grid_, + arg.p_d0_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + 
ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t /* KBatch */ = 1) override + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasAddReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bde0d48c15e012edfbe2abf930c029f02f82059c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DEGridDesc_M0_M1_M2_N0_N1 +{ + ck::index_t M0_, M1_, M2_, N0_, N1_; + ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_; +}; + +// input : A[M, K], B[K, N], +// input : D[M, N], ... 
+// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D) +template +struct DeviceGemmBiasCPermute : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmBiasCPermutePtr = std::unique_ptr< + DeviceGemmBiasCPermute>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f74cb0dc84021ce4140f7be42ef8f3639b0c190c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_bias_c_permute(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck 
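Editor's note: the DEGridDesc_M0_M1_M2_N0_N1 struct introduced above encodes a permuted [M, N] output, with M split into M0*M1*M2 and N into N0*N1, each sub-dimension carrying its own stride. The following is a small, hypothetical sketch (not part of this diff) of filling it for a packed [M0, N0, M1, M2, N1] memory layout; all sizes are made up for illustration.

// Hypothetical sketch: describe an E tensor of logical shape [M, N] with
// M = M0*M1*M2 and N = N0*N1, physically stored as the packed permutation
// [M0, N0, M1, M2, N1] (innermost dimension N1 has stride 1).
ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_desc{};
e_desc.M0_ = 4;  e_desc.M1_ = 16; e_desc.M2_ = 8;   // M = 4 * 16 * 8 = 512
e_desc.N0_ = 32; e_desc.N1_ = 64;                   // N = 32 * 64    = 2048
e_desc.stride_N1_ = 1;
e_desc.stride_M2_ = e_desc.N1_;                      // 64
e_desc.stride_M1_ = e_desc.M2_ * e_desc.stride_M2_;  // 8 * 64   = 512
e_desc.stride_N0_ = e_desc.M1_ * e_desc.stride_M1_;  // 16 * 512 = 8192
e_desc.stride_M0_ = e_desc.N0_ * e_desc.stride_N0_;  // 32 * 8192 = 262144
// MakeEGridDescriptor_M_N below merges (M0, M1, M2) -> M and (N0, N1) -> N
// using exactly these per-dimension strides.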
+ +namespace ck { +namespace tensor_operation { +namespace device { + +// input : A[M, K], or A[K, N] +// input : B[K, N], or A[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute +{ + using DeviceOp = DeviceGemmBiasCPermute_Xdl; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t NumDTensor = I1; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + 
make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc) + { + index_t M0 = d_e_grid_desc.M0_; + index_t M1 = d_e_grid_desc.M1_; + index_t M2 = d_e_grid_desc.M2_; + index_t N0 = d_e_grid_desc.N0_; + index_t N1 = d_e_grid_desc.N1_; + + index_t stride_M0 = d_e_grid_desc.stride_M0_; + index_t stride_M1 = 
d_e_grid_desc.stride_M1_; + index_t stride_M2 = d_e_grid_desc.stride_M2_; + index_t stride_N0 = d_e_grid_desc.stride_N0_; + index_t stride_N1 = d_e_grid_desc.stride_N1_; + + const auto MRaw = M0 * M1 * M2; + const auto NRaw = N0 * N1; + + const auto c_grid_desc_mraw_nraw = [&]() { + const auto c_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor( + make_tuple(M0, M1, M2, N0, N1), + make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1)); + + return transform_tensor_descriptor( + c_grid_desc_m0_m1_m2_n0_n1, + make_tuple(make_merge_transform(make_tuple(M0, M1, M2)), + make_merge_transform(make_tuple(N0, N1))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{})); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + 
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + const void* p_d_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + + if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + p_ds_grid_(I0) = static_cast(p_d_grid); + + const auto d_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_bias_c_permute< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_d, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + d_grid_desc, + e_grid_desc, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_d, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + d_grid_desc, + e_grid_desc, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return 
std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasCPermute_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp index 5ccf1934fee5310af214bab238926b3fd964289b..aca413f8a0456eaa43fea1e431e59d6ac276a31b 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp @@ -1,19 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include #include -#include "device.hpp" -#include "device_prop.hpp" -#include "device_base.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gemm_specialization.hpp" -#include "element_wise_operation.hpp" -#include "gridwise_gemm_dl_v1r3.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -63,8 +64,16 @@ template < is_same_v && is_same_v, bool> = false> -struct DeviceGemmDl - : public DeviceGemm +struct DeviceGemmDl : public DeviceGemm + { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -533,8 +542,7 @@ struct DeviceGemmDl index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override + CElementwiseOperation c_element_op) override { return std::make_unique(static_cast(p_a), static_cast(p_b), diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9c0594e38cfeae95508342e136f9445a21f4fdc4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmMultipleDPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8c7f2c15f043878af25c31a9a6b4323d3b76b6d2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,758 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_xdl_cshuffle(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace 
tensor_operation { +namespace device { + +// GEMM: +// input : A[AK0, M, AK1] +// input : B[AK0, N, AK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + 
make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { 
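+ // Column-major E: M is the unit-stride dimension; StrideE is the leading dimension (element distance between consecutive columns).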
+ return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + 
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + const auto d_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + } + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const 
override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index 7e387049c7dd04df7a15db37d610d7fc549e8721..fcc088ca43d1b8c6224c0015e3eb7434038af4b2 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -1,52 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + #include + #include "device_base.hpp" namespace ck { namespace tensor_operation { namespace device { -template +// FIXME: DeviceGemmReduce type need to well define the problem +template struct DeviceGemmReduce : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, + const void* p_bias, + std::array p_ds, void* p_c, - DPtrsGlobal p_dxs, + std::array p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - DxsInElementwiseOperation dxs_in_element_op, - DxsAccElementwiseOperation dxs_out_element_op, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_ops, + std::array reduce_out_element_ops, ck::index_t BatchCount = 1) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; -template -using DeviceGemmReducePtr = std::unique_ptr>; +template +using DeviceGemmReducePtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index f36db1a9e0ec75be2ff0f7038daa1ea543fef03d..722ae1137befc39846562a32efc1b98cc5e2c24e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -1,14 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
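For orientation, a minimal client-side sketch of driving the DeviceGemmMultipleD interface defined above. This is illustrative only and not part of the patch: the instance alias DeviceOpInstance, the element-wise functor types, the single-D configuration, and the device pointers (p_a, p_b, p_d0, p_e) are all assumptions.

// Assumed: DeviceOpInstance is some concrete DeviceGemmMultipleD_Xdl_CShuffle specialization
// with one D tensor, and p_a/p_b/p_d0/p_e point to already-allocated device buffers.
std::array<const void*, 1> p_ds{p_d0};
std::array<ck::index_t, 1> StrideDs{StrideD0};

auto device_op = DeviceOpInstance{};
auto invoker   = device_op.MakeInvoker();
auto argument  = device_op.MakeArgument(p_a, p_b, p_ds, p_e,
                                        M, N, K,
                                        StrideA, StrideB, StrideDs, StrideE,
                                        AElementOp{}, BElementOp{}, CDEElementOp{});

if(device_op.IsSupportedArgument(argument))
{
    // time_kernel = true makes the invoker return the averaged kernel time in ms
    const float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
}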
+ #pragma once + #include #include -#include "device.hpp" -#include "device_gemm_reduce.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" -#include "gemm_specialization.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -26,14 +32,14 @@ template -struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce +struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> { using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; @@ -345,8 +346,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce; @@ -571,16 +573,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce; @@ -611,16 +613,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(p_arg)); } - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - DPtrsGlobal p_dxs, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - DxsInElementwiseOperation dxs_in_element_op, - DxsAccElementwiseOperation dxs_out_element_op) + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) { - return Argument{p_a, - p_b, - p_c, - p_dxs, - MRaw, - NRaw, - KRaw, + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, StrideA, StrideB, StrideC, a_element_op, b_element_op, c_element_op, - dxs_in_element_op, - 
dxs_out_element_op}; + reduce_in_element_ops, + reduce_out_element_ops}; } static auto MakeInvoker() { return Invoker{}; } // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - DPtrsGlobal p_dxs, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - DxsInElementwiseOperation dxs_in_element_op, - DxsAccElementwiseOperation dxs_out_element_op, - index_t /* KBatch */ = 1) override + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t = 1) override { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), - p_dxs, - MRaw, - NRaw, - KRaw, + reduce_tuple, + M, + N, + K, StrideA, StrideB, StrideC, a_element_op, b_element_op, c_element_op, - dxs_in_element_op, - dxs_out_element_op); + reduce_in_element_ops, + reduce_out_element_ops); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp similarity index 53% rename from include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp rename to include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp index 95736b1887042159388a9dd4dc8ea971aee72d5b..c701bff57f8eb7db155e9abdfd3ab7210e7eeffd 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -1,22 +1,31 @@ -#ifndef DEVICE_GEMM_BIAS_ACTIVATION_HPP -#define DEVICE_GEMM_BIAS_ACTIVATION_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
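The reworked DeviceGemmReduce entry points above take type-erased std::array<void*> bundles and cast each pointer back to the concrete functor type inside the device op. A hedged caller-side sketch of that convention follows; the functor type names, the single-reduction setup, and the device pointers are assumptions for illustration, and the exact parameter list follows the base-class signature shown above.

// Assumed functor types; the device op dereferences each void* as the matching concrete type,
// so the objects must stay valid for the duration of the MakeArgumentPointer call.
AElementOp  a_op{};
BElementOp  b_op{};
CElementOp  c_op{};
ReduceInOp  reduce_in_op{};
ReduceAccOp reduce_acc_op{};

std::array<void*, 3> gemm_element_ops       = {&a_op, &b_op, &c_op}; // {A, B, C} order
std::array<void*, 1> reduce_in_element_ops  = {&reduce_in_op};
std::array<void*, 1> reduce_out_element_ops = {&reduce_acc_op};
std::array<void*, 1> p_reduces              = {p_reduce0};           // one output buffer per reduction

// NumDTensor is 0 for DeviceGemmReduce_Xdl_CShuffle, so the D-related arrays are empty.
auto argument_ptr = device_op.MakeArgumentPointer(p_a, p_b, /*p_bias=*/nullptr, {}, p_c, p_reduces,
                                                  M, N, K, StrideA, StrideB, StrideC, {},
                                                  gemm_element_ops, {},
                                                  reduce_in_element_ops, reduce_out_element_ops);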
+ +#pragma once #include +#include + #include "device_base.hpp" namespace ck { namespace tensor_operation { namespace device { -template -struct DeviceGemmBiasActivation : public BaseOperator +struct DeviceGemmSplitK : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c, - const void* p_c0, ck::index_t M, ck::index_t N, ck::index_t K, @@ -26,18 +35,30 @@ struct DeviceGemmBiasActivation : public BaseOperator AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, - ck::index_t KBatch = 1) = 0; + ck::index_t KBatch) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; -template -using DeviceGemmBiasActivationPtr = std::unique_ptr< - DeviceGemmBiasActivation>; +using DeviceGemmSplitKPtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index 3a8e1390e473dd5d02753a38ce8b3a730ecbbf95..98028e1f2830f917305535a30f27226eec3dce28 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -1,17 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include #include -#include "device.hpp" -#include "device_prop.hpp" -#include "device_base.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" -#include "gemm_specialization.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -54,8 +57,15 @@ template -struct DeviceGemmXdl - : public DeviceGemm +struct DeviceGemmXdl : public DeviceGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -484,8 +494,7 @@ struct DeviceGemmXdl index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override + CElementwiseOperation c_element_op) override { return std::make_unique(static_cast(p_a), static_cast(p_b), diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp deleted file mode 100644 index 1db69dd46207468f18136742f7ca1e038a633e45..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp +++ /dev/null @@ -1,506 +0,0 @@ -#pragma once -#include -#include -#include "device.hpp" -#include "device_gemm_bias.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r2.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - 
-template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - ck::index_t BlockSize, - ck::index_t MPerBlock, - ck::index_t NPerBlock, - ck::index_t K0PerBlock, - ck::index_t K1, - ck::index_t MPerXDL, - ck::index_t NPerXDL, - ck::index_t MXdlPerWave, - ck::index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - ck::index_t ABlockTransferSrcVectorDim, - ck::index_t ABlockTransferSrcScalarPerVector, - ck::index_t ABlockTransferDstScalarPerVector_K1, - bool ABlockLdsAddExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - ck::index_t BBlockTransferSrcVectorDim, - ck::index_t BBlockTransferSrcScalarPerVector, - ck::index_t BBlockTransferDstScalarPerVector_K1, - bool BBlockLdsAddExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> -struct DeviceGemmXdl_C_Shuffle_Bias_2d - : public DeviceGemmBias -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - static constexpr auto K1Number = Number{}; - - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_k0_m_k1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_k0_m_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_k0_n_k1 = - transform_tensor_descriptor(b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_k0_n_k1; - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = 
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - C0GridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, - BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl>; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - const CDataType* p_bias_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c0_grid_{p_bias_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c0_grid_desc_m_n_{}, - c_grid_desc_m_n_{}, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - a_grid_desc_k0_m_k1_ = - DeviceGemmXdl_C_Shuffle_Bias_2d::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = - DeviceGemmXdl_C_Shuffle_Bias_2d::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c0_grid_desc_m_n_ = - DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC); - c_grid_desc_m_n_ = - DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c0_grid_desc_m_n_); - - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - const CDataType* p_c0_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 
b_grid_desc_k0_n_k1_; - C0GridDesc_M_N c0_grid_desc_m_n_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) - << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error( - "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 has invalid setting"); - } - - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - - float ave_time = 0; - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - const auto kernel = kernel_gemm_xdlops_v3r2< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - remove_reference_t< - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v3r2< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - remove_reference_t< - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - const CDataType* p_bias, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_b, - p_bias, - p_c, - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - 
a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - const void* p_bias, - void* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_bias), - static_cast(p_c), - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmXdl_C_Shuffle_Bias_2d" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp deleted file mode 100644 index b465f8e4aeedaf946775acab031d634d184db5b9..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp +++ /dev/null @@ -1,516 +0,0 @@ -#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_HPP -#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_HPP - -#include -#include -#include "device.hpp" -#include "device_gemm_bias_activation.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r2.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// C[M, N] = activate(A[M, K] * B[K, N] + C0[N]) -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - ck::index_t BlockSize, - ck::index_t MPerBlock, - ck::index_t NPerBlock, - ck::index_t K0PerBlock, - ck::index_t K1, - ck::index_t MPerXDL, - ck::index_t NPerXDL, - ck::index_t MXdlPerWave, - ck::index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - ck::index_t ABlockTransferSrcVectorDim, - ck::index_t ABlockTransferSrcScalarPerVector, - ck::index_t ABlockTransferDstScalarPerVector_K1, - bool ABlockLdsAddExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - ck::index_t BBlockTransferSrcVectorDim, - ck::index_t BBlockTransferSrcScalarPerVector, - ck::index_t BBlockTransferDstScalarPerVector_K1, - bool BBlockLdsAddExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> -struct DeviceGemmXdl_C_Shuffle_Bias_Activation - : public DeviceGemmBiasActivation -{ - using 
DeviceOp = DeviceGemmXdl_C_Shuffle_Bias_Activation; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto K1Number = Number{}; - - static auto MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N( - index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - // A[K0, M, K1] - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_k0_m_k1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // B[K0, N, K1] - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_k0_n_k1 = - transform_tensor_descriptor(b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C[M, N] - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - // C0[N]: assume a contiguous vector - const auto c0_grid_desc_m_n = - make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); - - return make_tuple( - a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, c0_grid_desc_m_n); - } - - using GridDescs = - decltype(MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N(1, 1, 1, 1, 1, 1)); - - using AGridDesc_K0_M_K1 = remove_cvref_t; - using BGridDesc_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - using C0GridDesc_M_N = remove_cvref_t; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - C0GridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsAddExtraN, - 
CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl>; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - const CDataType* p_c0_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - p_c0_grid_{p_c0_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c0_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - const auto descs = DeviceOp::MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N( - M, N, K, StrideA, StrideB, StrideC); - - a_grid_desc_k0_m_k1_ = descs[I0]; - b_grid_desc_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; - c0_grid_desc_m_n_ = descs[I3]; - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c0_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - const CDataType* p_c0_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - C0GridDesc_M_N c0_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << 
"arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) - << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting"); - } - - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - - float ave_time = 0; - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - const auto kernel = kernel_gemm_xdlops_v3r2< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - remove_reference_t< - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v3r2< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - remove_reference_t< - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return 
IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - const CDataType* p_c0, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_b, - p_c, - p_c0, - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - const void* p_c0, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_c), - static_cast(p_c0), - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmXdl_C_Shuffle_Bias_Activation" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp deleted file mode 100644 index 7a2e1886d354ce60948c4a1015fabda2c06dc88b..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp +++ /dev/null @@ -1,576 +0,0 @@ -#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_HPP -#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_HPP - -#include -#include -#include "device.hpp" -#include "device_gemm_bias_activation_add.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r3.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// C[M, N] = activate(A[M, K] * B[K, N] + C0[N]) + C1[M, N] -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - ck::index_t BlockSize, - ck::index_t MPerBlock, - ck::index_t NPerBlock, - ck::index_t K0PerBlock, - ck::index_t K1, - ck::index_t MPerXDL, - ck::index_t NPerXDL, - ck::index_t MXdlPerWave, - ck::index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - ck::index_t ABlockTransferSrcVectorDim, - ck::index_t ABlockTransferSrcScalarPerVector, - ck::index_t ABlockTransferDstScalarPerVector_K1, - bool ABlockLdsAddExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - 
typename BBlockTransferSrcAccessOrder, - ck::index_t BBlockTransferSrcVectorDim, - ck::index_t BBlockTransferSrcScalarPerVector, - ck::index_t BBlockTransferDstScalarPerVector_K1, - bool BBlockLdsAddExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> -struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add - : public DeviceGemmBiasActivationAdd -{ - using DeviceOp = DeviceGemmXdl_C_Shuffle_Bias_Activation_Add; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - - static constexpr auto K1Number = Number{}; - - static auto MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t StrideC1) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - // A[K0, M, K1] - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_k0_m_k1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // B[K0, N, K1] - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_k0_n_k1 = - transform_tensor_descriptor(b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C[M, N] - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - // C0[N]: assume a contiguous vector - const auto c0_grid_desc_m_n = - make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); - - // C1[M, N]: residual tensor: assume same layout as C - const auto c1_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC1, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC1)); - } - }(); - - return make_tuple(a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m_n, - c0_grid_desc_m_n, - c1_grid_desc_m_n); - } - - using GridDescs = - decltype(MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(1, 1, 1, 1, 1, 1, 1)); - - using AGridDesc_K0_M_K1 = remove_cvref_t; - using BGridDesc_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - using C0GridDesc_M_N = remove_cvref_t; - using C1GridDesc_M_N = remove_cvref_t; - - // 
GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - C0GridDesc_M_N, - C1GridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl>; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - const CDataType* p_c0_grid, - const CDataType* p_c1_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t StrideC1, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - p_c0_grid_{p_c0_grid}, - p_c1_grid_{p_c1_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c0_grid_desc_m_n_{}, - c1_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - const auto descs = DeviceOp::MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N( - M, N, K, StrideA, StrideB, StrideC, StrideC1); - - a_grid_desc_k0_m_k1_ = descs[I0]; - b_grid_desc_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; - c0_grid_desc_m_n_ = descs[I3]; - c1_grid_desc_m_n_ = descs[I4]; - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c0_grid_desc_m_n_); - - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c1_grid_desc_m_n_); - } - 
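Note on the C0 descriptor built above: the bias vector C0[N] is presented to the gridwise GEMM as an M x N view whose M-stride is zero, so every row reads the same N values. The following self-contained sketch, written in plain C++ rather than CK's descriptor API, shows how a zero stride turns a length-N vector into a broadcast M x N operand; all names in it are illustrative only.

#include <cstdio>
#include <vector>

// Conceptual sketch: a 2-D "view" is just an offset function built from strides.
// With stride_m == 0 the same bias element is returned for every row, which is
// what make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1))
// expresses for the C0[N] operand above.
struct View2D
{
    const float* data;
    long stride_m;
    long stride_n;
    float at(long m, long n) const { return data[m * stride_m + n * stride_n]; }
};

int main()
{
    const long M = 3, N = 4;
    std::vector<float> bias{0.5f, 1.0f, 1.5f, 2.0f};              // C0[N]
    View2D c0_view{bias.data(), /*stride_m=*/0, /*stride_n=*/1};  // broadcast along M

    for(long m = 0; m < M; ++m)
    {
        for(long n = 0; n < N; ++n)
            std::printf("%4.1f ", c0_view.at(m, n));
        std::printf("\n");
    }
    return 0;
}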
} - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - const CDataType* p_c0_grid_; - const CDataType* p_c1_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - C0GridDesc_M_N c0_grid_desc_m_n_; - C1GridDesc_M_N c1_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) - << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) - << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error( - "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting"); - } - - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - - float ave_time = 0; - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - const auto kernel = kernel_gemm_xdlops_v3r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - remove_reference_t< - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - remove_reference_t< - typename GridwiseGemm:: - C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.p_c1_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v3r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - remove_reference_t< - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - remove_reference_t< - typename GridwiseGemm:: - C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.p_c1_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return 
IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - const CDataType* p_c0, - const CDataType* p_c1, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t StrideC1, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_b, - p_c, - p_c0, - p_c1, - M, - N, - K, - StrideA, - StrideB, - StrideC, - StrideC1, - 1, - 1, - a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - const void* p_c0, - const void* p_c1, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t StrideC1, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_c), - static_cast(p_c0), - static_cast(p_c1), - M, - N, - K, - StrideA, - StrideB, - StrideC, - StrideC1, - 1, - 1, - a_element_op, - b_element_op, - c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmXdl_C_Shuffle_Bias_Activation_Add" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index a74ee81679991c264d14c9386d71aa6c788e23f3..9c8b189add092b1aae37e550740e3d2deadafb98 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -1,15 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
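The device ops removed above differ from a plain GEMM only in their epilogue: after the accumulator tile is produced, the CElementwiseOperation combines it with the broadcast bias (and, in the _add variant, a residual tensor) before the store. A minimal self-contained sketch of that epilogue idea follows; the AddRelu functor is modelled loosely on CK's element-wise operators but is written here as plain C++, so treat the exact signature as an assumption.

#include <algorithm>
#include <cstdio>
#include <vector>

// Epilogue functor in the spirit of CElementwiseOperation: combine the GEMM
// accumulator with a bias value and apply an activation before storing.
struct AddRelu
{
    void operator()(float& y, const float& acc, const float& bias) const
    {
        y = std::max(acc + bias, 0.f);
    }
};

int main()
{
    const int M = 2, N = 3;
    std::vector<float> acc{-1.f, 0.5f, 2.f, 3.f, -4.f, 1.f}; // A*B result, M x N
    std::vector<float> bias{0.25f, -1.f, 0.f};               // C0[N], broadcast along M
    std::vector<float> c(M * N);

    AddRelu epilogue;
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            epilogue(c[m * N + n], acc[m * N + n], bias[n]);

    for(float v : c) std::printf("%g ", v);
    std::printf("\n");
    return 0;
}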
+ #pragma once + #include #include -#include "device.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdl_cshuffle_v1.hpp" -#include "tensor_operation/gpu/device/gemm_specialization.hpp" -#include "device_prop.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -60,8 +65,15 @@ template -struct DeviceGemm_Xdl_CShuffle - : public DeviceGemm +struct DeviceGemm_Xdl_CShuffle : public DeviceGemm { using DeviceOp = DeviceGemm_Xdl_CShuffle; @@ -617,8 +629,7 @@ struct DeviceGemm_Xdl_CShuffle index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override + CElementwiseOperation c_element_op) override { return std::make_unique(static_cast(p_a), static_cast(p_b), diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b82fcb67f7c6d5a9c449658504c3442eb7e9c6d3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +// +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. 
+// +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator +{ + using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto 
b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - 
MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + static auto MakeGridDescriptor_N(index_t NRaw) + { + const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw)); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad N + return transform_tensor_descriptor(grid_desc_nraw, + make_tuple(make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad N + return grid_desc_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_N = decltype(MakeGridDescriptor_N(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + C0DataType, + ReduceAccDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + 
CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + LoopSched>; + + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const C0DataType* p_c0_grid_add, + const C0DataType* p_c0_grid_bias, + const C0DataType* p_c0_grid_gamma, + const C0DataType* p_c0_grid_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_c0_grid_bias_{p_c0_grid_bias}, + p_c0_grid_add_{p_c0_grid_add}, + p_c0_grid_gamma_{p_c0_grid_gamma}, + p_c0_grid_beta_{p_c0_grid_beta}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_nblock_nperblock_{}, + block_2_ctile_map_{Block2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_nblock_nperblock_ = + GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const C0DataType* p_c0_grid_bias_; + const C0DataType* p_c0_grid_add_; + const C0DataType* p_c0_grid_gamma_; + const C0DataType* p_c0_grid_beta_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_N c0_grid_desc_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + 
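As a reference for what this fused kernel computes, the sketch below evaluates D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) row by row on the host, matching the formula in this file's header comment. It is a minimal float implementation for checking small cases, not the device code path; the identity acc_op and the epsilon value are assumptions.

#include <cmath>
#include <vector>

// Host reference for one M x N problem. acc_op models AccElementwiseOperation
// (identity here); eps is an assumed numerical-stability constant.
void gemm_layernorm_reference(const std::vector<float>& A,     // M x K, row-major
                              const std::vector<float>& B,     // K x N, row-major
                              const std::vector<float>& bias,  // N
                              const std::vector<float>& add,   // M x N
                              const std::vector<float>& gamma, // N
                              const std::vector<float>& beta,  // N
                              std::vector<float>& D,           // M x N
                              int M, int N, int K, float eps = 1e-5f)
{
    std::vector<float> acc(N);
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float c = 0.f;
            for(int k = 0; k < K; ++k)
                c += A[m * K + k] * B[k * N + n];
            acc[n] = (c + bias[n]) + add[m * N + n]; // acc_op(A*B + bias) + add, acc_op == identity
        }

        // per-row mean / variance, then scale and shift
        float mean = 0.f, var = 0.f;
        for(int n = 0; n < N; ++n) mean += acc[n];
        mean /= N;
        for(int n = 0; n < N; ++n) var += (acc[n] - mean) * (acc[n] - mean);
        var /= N;

        const float inv_std = 1.f / std::sqrt(var + eps);
        for(int n = 0; n < N; ++n)
            D[m * N + n] = (acc[n] - mean) * inv_std * gamma[n] + beta[n];
    }
}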
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, + Block2CTileMap, + true>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_nblock_nperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, + Block2CTileMap, + false>; + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_nblock_nperblock_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const C0DataType* p_c0_bias, + const C0DataType* p_c0_add, + const C0DataType* p_c0_gamma, + const C0DataType* p_c0_beta, + index_t MRaw, + index_t NRaw, + 
index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_c0_bias, + p_c0_add, + p_c0_gamma, + p_c0_beta, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0_bias, + const void* p_c0_add, + const void* p_c0_gamma, + const void* p_c0_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_c0_bias), + static_cast(p_c0_add), + static_cast(p_c0_gamma), + static_cast(p_c0_beta), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + } + + std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmLayerNorm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp index d9fc8f7a8a79e7a0f2a4d4354a58a6195ee61d7d..306a73dff15072ec18c7dc890f36a70234035517 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp @@ -1,22 +1,20 @@ -#ifndef DEVICE_GEMM_SPLITK_XDL_HPP -#define DEVICE_GEMM_SPLITK_XDL_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
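The MakeAGridDescriptor_AK0_M_AK1 / MakeBGridDescriptor_BK0_N_BK1 / MakeCGridDescriptor_M_N helpers earlier in this file pad each raw extent up to a multiple of its per-block tile before splitting K into (K0, K1). The small sketch below reproduces that arithmetic with concrete numbers; it only illustrates the index math and is not CK code.

#include <cassert>
#include <cstdio>

// Mirrors the padding arithmetic used by the descriptor helpers above: each raw
// extent is rounded up to the next multiple of its per-block tile size, and the
// difference is applied as a right pad.
int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    const int MRaw = 1000, MPerBlock = 128;

    const int M    = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 1024
    const int MPad = M - MRaw;                                         // 24

    std::printf("MRaw=%d -> padded M=%d (MPad=%d)\n", MRaw, M, MPad);

    // The unmerge K -> (AK0, AK1) in the same helpers requires K % AK1 == 0,
    // which is why the padded (or raw) K is asserted to be divisible by AK1.
    const int K = 256, AK1 = 8;
    assert(K % AK1 == 0);
    const int AK0 = K / AK1; // 32
    std::printf("K=%d = AK0(%d) x AK1(%d)\n", K, AK0, AK1);
    return 0;
}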
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r4.hpp" -#include "gemm_specialization.hpp" -#include "device_prop.hpp" - -#ifndef CK_RUN_KERNEL_AND_TIME -#define CK_RUN_KERNEL_AND_TIME 1 -#endif + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -58,8 +56,15 @@ template -struct DeviceGemmXdlSplitK - : public DeviceGemm +struct DeviceGemmXdlSplitK : public DeviceGemmSplitK { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -639,4 +644,3 @@ struct DeviceGemmXdlSplitK } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp index ad424d91d97eb6818a8980ad90abee0b9e49e395..52bdacf7dbe9d8cc2052b2642099eead654d5ba6 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp @@ -1,21 +1,20 @@ -#ifndef DEVICE_GEMM_XDL_SPLITK_C_SHUFFLE_HPP -#define DEVICE_GEMM_XDL_SPLITK_C_SHUFFLE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
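DeviceGemmXdlSplitK above (and the c-shuffle variant below) now derive from the dedicated DeviceGemmSplitK base, whose extra KBatch parameter splits the K loop across several workgroups. The sketch below models the idea on the host under the assumption that partial products are simply summed; the real kernel combines partials on the device, so this is only a numerical model of the decomposition.

#include <algorithm>
#include <vector>

// Split-K model: the K dimension is cut into KBatch chunks, each chunk produces a
// partial C, and the partials are accumulated into the final C. Splitting K does
// not change the result, it only changes how the work is parallelised.
void gemm_split_k(const std::vector<float>& A, // M x K, row-major
                  const std::vector<float>& B, // K x N, row-major
                  std::vector<float>& C,       // M x N, zero-initialised
                  int M, int N, int K, int KBatch)
{
    const int KPerBatch = (K + KBatch - 1) / KBatch;

    for(int kb = 0; kb < KBatch; ++kb) // each batch could run concurrently
    {
        const int k_begin = kb * KPerBatch;
        const int k_end   = std::min(K, k_begin + KPerBatch);

        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float partial = 0.f;
                for(int k = k_begin; k < k_end; ++k)
                    partial += A[m * K + k] * B[k * N + n];
                C[m * N + n] += partial; // accumulate the partial result
            }
    }
}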
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r4r2.hpp" -#include "gemm_specialization.hpp" - -#ifndef CK_RUN_KERNEL_AND_TIME -#define CK_RUN_KERNEL_AND_TIME 1 -#endif + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -59,8 +58,15 @@ template -struct DeviceGemmXdlSplitKCShuffle - : public DeviceGemm +struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -421,21 +427,22 @@ struct DeviceGemmXdlSplitKCShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * sizeof(CDataType))); - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); }; if(has_main_k0_block_loop) @@ -641,4 +648,3 @@ struct DeviceGemmXdlSplitKCShuffle } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index 0617b4fcb7f920ca5c522b95a79cfcf37aa3d027..999792807bd16f38e1681a9fa001f75cf3d5c170 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -1,17 +1,20 @@ -#ifndef DEVICE_GROUPED_GEMM_XDL_HPP -#define DEVICE_GROUPED_GEMM_XDL_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
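The hunk above fixes the split-K c-shuffle invoker so that the value returned by launch_and_time_kernel is actually stored in ave_time; previously Run() always reported 0. A minimal stand-alone sketch of the same measure-and-return pattern, using std::chrono instead of the HIP timing helper, is shown below for reference only.

#include <chrono>
#include <cstdio>

// Minimal stand-in for launch_and_time_kernel: run a callable and return the
// elapsed time in milliseconds. The point of the hunk above is simply that this
// return value has to be assigned (ave_time = ...) or the caller reports 0.
template <typename F>
float time_in_ms(F&& f)
{
    const auto t0 = std::chrono::high_resolution_clock::now();
    f();
    const auto t1 = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count();
}

int main()
{
    float ave_time = time_in_ms([] {
        volatile double x = 0;
        for(int i = 0; i < 1000000; ++i) x += i;
    });
    std::printf("elapsed: %f ms\n", ave_time);
    return 0;
}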
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" -#include "gemm_specialization.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -43,13 +46,22 @@ __global__ void const auto gemm_desc_ptr = reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - index_t group_id = 0; - for(index_t i = 0; i < group_count; i++) + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) { - group_id = - (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_) - ? i - : group_id; + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } GridwiseGemm::template Run( @@ -362,7 +374,7 @@ struct DeviceGroupedGemmXdl { grid_size_ = 0; - gemm_descs_args_workspace_ = nullptr; + p_workspace_ = nullptr; group_count_ = ck::type_convert(gemm_shapes.size()); @@ -437,8 +449,6 @@ struct DeviceGroupedGemmXdl std::vector gemm_desc_kernel_arg_; - void* gemm_descs_args_workspace_; - index_t grid_size_; }; @@ -488,7 +498,7 @@ struct DeviceGroupedGemmXdl } hipGetErrorString( - hipMemcpy(arg.gemm_descs_args_workspace_, + hipMemcpy(arg.p_workspace_, arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg), hipMemcpyHostToDevice)); @@ -507,17 +517,17 @@ struct DeviceGroupedGemmXdl CElementwiseOperation, true>; - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), - arg.gemm_desc_kernel_arg_.size(), - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); } else { @@ -531,17 +541,17 @@ struct DeviceGroupedGemmXdl CElementwiseOperation, false>; - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), - arg.gemm_desc_kernel_arg_.size(), - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); } return ave_time; @@ -635,14 +645,8 @@ struct 
DeviceGroupedGemmXdl { return dynamic_cast(p_arg)->group_count_ * sizeof(GemmDescKernelArg); } - - void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override - { - dynamic_cast(p_arg)->gemm_descs_args_workspace_ = workspace_ptr; - } }; } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0e4313f17d99dcb8ff1bb89906a3b5fab01085bd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DeviceNormalization : public BaseOperator +{ + // inLengths: input tensor extent(s) from high to low dimension + // inStrides: input tensor stride(s) from high to low dimension + // reduceDims: the dimension(s) the normalization operation is applied + // alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType + // beta: typeless pointer in host memory storing the beta scaling value of type AccDataType + // in_dev: typeless const pointer in device memory storing the input tensor + // out_dev: typeless pointer in device memory storing the output tensor + virtual std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + const void* alpha, + const void* beta, + const void* in_dev, + void* out_dev) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual index_t GetRank() const = 0; + + virtual index_t GetNumReduceDim() const = 0; +}; + +using DeviceNormalizationPtr = std::unique_ptr; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp index d049f6e97911d093b6e9b0203743d4ff3b385886..3b376c6f73f066729fa118013970682ca31b340e 100644 --- a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp @@ -1,10 +1,13 @@ -#ifndef DEVICE_POOL2D_FWD_HPP -#define DEVICE_POOL2D_FWD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
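The grouped-GEMM kernel change above replaces the linear scan for group_id with a bisection over the per-group [BlockStart_, BlockEnd_) ranges, which are contiguous and sorted by construction. The sketch below reproduces the same lookup on the host in plain C++ so the invariant is easy to check; the struct and field names mirror the fields in the diff, but the code itself is only an illustration.

#include <cstdio>
#include <vector>

// Each group owns a contiguous, sorted range of workgroup ids [BlockStart_, BlockEnd_).
struct GroupRange
{
    int BlockStart_;
    int BlockEnd_;
};

// Same idea as the kernel-side lookup above: bisect until the range containing
// block_id is found. Returns -1 if block_id lies outside every range, which cannot
// happen for a grid sized from the ranges themselves.
int find_group_id(const std::vector<GroupRange>& groups, int block_id)
{
    int left = 0, right = static_cast<int>(groups.size()) - 1;
    while(left <= right)
    {
        const int mid = (left + right) / 2;
        if(block_id < groups[mid].BlockStart_)
            right = mid - 1;
        else if(block_id >= groups[mid].BlockEnd_)
            left = mid + 1;
        else
            return mid;
    }
    return -1;
}

int main()
{
    // three groups owning 4, 2 and 6 tiles respectively
    std::vector<GroupRange> groups{{0, 4}, {4, 6}, {6, 12}};
    for(int block_id : {0, 3, 4, 5, 11})
        std::printf("block %d -> group %d\n", block_id, find_group_id(groups, block_id));
    return 0;
}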
+ +#pragma once #include #include -#include "device_base.hpp" -#include "reduction_enums.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" namespace ck { namespace tensor_operation { @@ -35,4 +38,3 @@ using DevicePool2dFwdPtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp index c7e18d98dcddec5f32c518fcd829e9c70b1cb952..3edf9bd3aff10df7d0545989bd7861d05d20a39e 100644 --- a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp +++ b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp @@ -1,13 +1,18 @@ -#ifndef DEVICE_POOL2D_FWD_NHWC_NHWC_HPP -#define DEVICE_POOL2D_FWD_NHWC_NHWC_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "device_pool2d_fwd.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "reduction_operator_mapping.hpp" -#include "gridwise_2d_reduction_threadwise.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -35,14 +40,13 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd using IndexDataType = int32_t; - using ReduceOperation = typename reduce_binary_operator::opType; + using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; static constexpr index_t InSrcOutDstVectorDim = 0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is @@ -178,13 +182,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd invariant_lowest_length_ = C; reduce_lowest_length_ = window_spatial_lengths[1]; - // TODO: is this correct? 
- if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG) - { - ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; - in_element_op_ = InElementwiseOperation{divider}; - acc_element_op_ = AccElementwiseOperation{divider}; - } + int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); } const InDataType* p_in_dev_; @@ -319,9 +320,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd return str.str(); } -}; // namespace device +}; } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_reduce.hpp index 6f367a8747c36d8e839bb6f85e127d70f75e71e3..468d0b5ab9ef1d2fbfc4f853e082020849299f9f 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -1,13 +1,15 @@ -#ifndef DEVICE_REDUCE_HPP -#define DEVICE_REDUCE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include #include -#include "common_header.hpp" -#include "device_base.hpp" -#include "reduction_enums.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { namespace tensor_operation { @@ -41,4 +43,3 @@ using DeviceReducePtr = } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp index f68a392821711e068f482eaf496eb7ee7b82c652..42e74f29931d51a87fcccf5c8d75e04b5f2b7826 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp @@ -1,12 +1,14 @@ -#ifndef DEVICE_REDUCE_COMMON_HPP -#define DEVICE_REDUCE_COMMON_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "common_header.hpp" -#include "reduction_enums.hpp" -#include "reduction_operator.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_operator.hpp" namespace ck { namespace tensor_operation { @@ -85,6 +87,4 @@ std::vector shuffle_tensor_dimensions(const std::vector& origL } // namespace device } // namespace tensor_operation - } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp index 575c6bff1db478aaa9d4ba7d8494889fe5fe9572..a903fc415e8f4d52552f7ffc91a5347804467900 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp @@ -1,15 +1,21 @@ -#ifndef DEVICE_REDUCE_MULTIBLOCK_HPP -#define DEVICE_REDUCE_MULTIBLOCK_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
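The pooling hunk above stops special-casing AVG and instead derives both element-wise operators from the reduce length window_h * window_w. The sketch below is a plain C++ single-channel reference (no padding, no dilation) showing why that length is the right divider: average pooling is a window-sum reduction scaled by 1 / reduceLength. It is a host-side illustration, not the device kernel.

#include <vector>

// Single-channel 2-D average pooling as a reduction: sum over the window (ADD
// reduce op), then apply the "divide by reduce length" accumulator element-op.
std::vector<float> avg_pool2d_single_channel(const std::vector<float>& in,
                                             int H, int W,
                                             int window_h, int window_w,
                                             int stride_h, int stride_w)
{
    const int Ho = (H - window_h) / stride_h + 1;
    const int Wo = (W - window_w) / stride_w + 1;
    const float inv_reduce_length = 1.f / static_cast<float>(window_h * window_w);

    std::vector<float> out(Ho * Wo, 0.f);
    for(int ho = 0; ho < Ho; ++ho)
        for(int wo = 0; wo < Wo; ++wo)
        {
            float sum = 0.f; // reduction with ADD
            for(int y = 0; y < window_h; ++y)
                for(int x = 0; x < window_w; ++x)
                    sum += in[(ho * stride_h + y) * W + (wo * stride_w + x)];
            out[ho * Wo + wo] = sum * inv_reduce_length; // acc element-op: scale by 1/reduceLength
        }
    return out;
}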
+ +#pragma once #include #include -#include "device.hpp" -#include "device_base.hpp" -#include "device_reduce.hpp" -#include "device_reduce_common.hpp" -#include "gridwise_2d_reduction_multiblock.hpp" -#include "gridwise_set_buffer_value.hpp" -#include "reduction_operator.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -61,12 +67,9 @@ struct DeviceReduceMultiBlock : public DeviceReduce::value || std::is_same::value; - - static_assert( - !use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op), - "The OutDataType must support the atomic operation for using MultiBlock reduction"); + static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType::value, + "The OutDataType must support the specified OutMemoryDataOperation!"); static_assert(!use_multiblock || (use_multiblock && !OutputIndex), "MultiBlock reduction can only be used when outputing index is not required"); @@ -349,7 +352,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce( + ck::reduce::GetIdentityValueForInMemoryDataOperation( OutMemoryDataOperation); const auto kernel_pre = @@ -393,10 +396,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce(p_arg); - if constexpr(use_multiblock) { if(static_cast(pArg->beta_) != 0.0f) @@ -445,11 +446,16 @@ struct DeviceReduceMultiBlock : public DeviceReducereduce_total_length / KThreadSliceSize < 2) - return (false); + // if(pArg->reduce_total_length / KThreadSliceSize < 2) + // return (false); }; return (true); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(dynamic_cast(p_arg)); }; std::unique_ptr @@ -492,7 +498,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce"; @@ -505,4 +511,3 @@ struct DeviceReduceMultiBlock : public DeviceReduce #include -#include "device.hpp" -#include "device_reduce.hpp" -#include "device_reduce_common.hpp" -#include "gridwise_2d_reduction_multiblock.hpp" -#include "gridwise_2d_reduction_threadwise.hpp" + +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" namespace ck { namespace tensor_operation { @@ -370,4 +374,3 @@ struct DeviceReduceThreadWise : public DeviceReduce +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include 
"ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSoftmax : public DeviceNormalization +{ + static constexpr index_t kRank = Rank; + static constexpr index_t kNumReduceDim = NumReduceDim; + + virtual index_t GetRank() const override { return kRank; } + + virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } + + using PassThrough = tensor_operation::element_wise::PassThrough; + + // Used for freeloading of some handy functions from DeviceReduceMultiBlock + using Reduction = DeviceReduceMultiBlock; // OutDstVectorSize + + using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); + + using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk; + + using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk; + + struct Argument : public Reduction::Argument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + AccDataType alpha, + AccDataType beta, + const InDataType* in_dev, + OutDataType* out_dev) + : Reduction::Argument(inLengths, + inStrides, + {}, + {}, + reduceDims, + 0.0f, // alpha + 0.0f, // beta + in_dev, + nullptr, + out_dev, + nullptr, + PassThrough{}, + PassThrough{}), + // FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of + // float32 precision. Make it support any data type so the fields can be removed. + alpha_(alpha), + beta_(beta) + { + // std::cout << "blkGroupSize= " << this->blkGroupSize + // << ", numBlockTileIteration= " << this->numBlockTileIteration + // << ", gridSize=" << this->gridSize + // << ", invariant_total_length=" << this->invariant_total_length << + // std::endl; + } + + AccDataType alpha_; + AccDataType beta_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + + bool sweep_once = + in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + const auto kernel_main = sweep_once ? 
kernel_softmax + : kernel_softmax; + + float avg_time = 0; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m_k, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.beta_, + arg.out_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + if(!Reduction::IsSupportedArgument(p_arg_)) + { + return false; + } + + if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0) + { + return false; + } + + return true; + }; + + // inLengths: input tensor extent(s) from high to low dimension + // inStrides: input tensor stride(s) from high to low dimension + // reduceDims: the dimension(s) the softmax normalization operate on + // alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType + // beta: typeless pointer in host memory storing the beta scaling value as type AccDataType + // in_dev: typeless const pointer in device memory storing the input tensor + // out_dev: typeless pointer in device memory storing the output tensor + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + const void* alpha, + const void* beta, + const void* in_dev, + void* out_dev) override + { + return std::make_unique(inLengths, + inStrides, + reduceDims, + *static_cast(alpha), + *static_cast(beta), + static_cast(in_dev), + static_cast(out_dev)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceSoftmax<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..054245429d693564ddaf4b772802e0079acd7a64 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
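As a usage sketch for the DeviceSoftmax operation defined above: alpha and beta travel as typeless host pointers, the tensors as device pointers, and the argument is validated before launch. The instance, problem sizes, and vector element types below are hypothetical, since a concrete DeviceSoftmax instantiation fixes them through its template parameters:

#include <vector>

template <typename DeviceSoftmaxInstance>
float run_softmax_sketch(DeviceSoftmaxInstance& softmax, const void* p_in_dev, void* p_out_dev)
{
    float alpha = 1.0f; // read on the host through the typeless pointer
    float beta  = 0.0f;

    std::vector<ck::index_t> lengths = {8, 128, 2048};        // hypothetical shape
    std::vector<ck::index_t> strides = {128 * 2048, 2048, 1}; // packed, row-major
    std::vector<int> reduce_dims     = {2};                   // softmax over the last dim

    auto arg = softmax.MakeArgumentPointer(
        lengths, strides, reduce_dims, &alpha, &beta, p_in_dev, p_out_dev);
    auto invoker = softmax.MakeInvokerPointer();

    if(!softmax.IsSupportedArgument(arg.get()))
        return -1.0f;

    // StreamConfig{nullptr, true} asks the invoker to time the kernel launch.
    return invoker->Run(arg.get(), StreamConfig{nullptr, true});
}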
+ +#pragma once + +#include +#include + +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceUnaryElementwise : public BaseOperator +{ + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * ScalarPerVector; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); + using GridwiseUEltwise = GridwiseUnaryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + BDataType* p_b, + const std::vector& shape, + const std::vector& stride_a, + const std::vector& stride_b, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + shape_(shape), + functor_(functor), + blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future + { + index_t tensor_size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); + gridSize_ = GridwiseUEltwise::CalculateGridSize(tensor_size); + a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_); + b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_); + } + + const ADataType* p_a_; + BDataType* p_b_; + std::vector shape_; + GridDesc_M0 a_grid_desc_m0_; + GridDesc_M0 b_grid_desc_m0_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_unary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.a_grid_desc_m0_, + arg.b_grid_desc_m0_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; 
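For reference, the right padding computed by PadDescriptor_M0_1d above rounds the flattened length up to a whole number of grid-wide vector steps, so every thread runs the same number of loop iterations. The same arithmetic on plain integers, with hypothetical sizes:

// Illustration only; the real code applies math::integer_least_multiple to the descriptor length.
const int m0             = 1000;                                            // flattened tensor length
const int grid_size      = 2;
const int block_size     = 256;
const int scalar_per_vec = 4;
const int loop_step      = grid_size * block_size * scalar_per_vec;         // 2048
const int padded         = ((m0 + loop_step - 1) / loop_step) * loop_step;  // 2048
const int pad            = padded - m0; // 1048 trailing elements covered by the right-pad transform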
+ + if(pArg->shape_.back() % ScalarPerVector != 0) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer(const void* p_a, + void* p_b, + std::vector shape, + std::vector stride_a, + std::vector stride_b, + ElementwiseFunctor functor) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + shape, + stride_a, + stride_b, + functor); + } + + std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBinaryElementwise" + << "<" + << "ScalarPerVector = " << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index d4ef61a133a478e3237f3311c72b3213ca8430c3..927a92e6b4d1b2c2fcdb9ef324ffcb5f5e7a2173 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -1,5 +1,7 @@ -#ifndef GEMM_SPECIALIZATION -#define GEMM_SPECIALIZATION +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once namespace ck { namespace tensor_operation { @@ -17,7 +19,22 @@ enum struct GemmSpecialization MNKPadding, }; +inline std::string getGemmSpecializationString(const GemmSpecialization& s) +{ + switch(s) + { + case GemmSpecialization::Default: return "Default"; + case GemmSpecialization::MPadding: return "MPadding"; + case GemmSpecialization::NPadding: return "NPadding"; + case GemmSpecialization::KPadding: return "KPadding"; + case GemmSpecialization::MNPadding: return "MNPadding"; + case GemmSpecialization::MKPadding: return "MKPadding"; + case GemmSpecialization::NKPadding: return "NKPadding"; + case GemmSpecialization::MNKPadding: return "MNKPadding"; + default: return "Unrecognized specialization!"; + } +} + } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp index 634e9212ea8ac1a92e496b86d6409faa3f99eb3f..d35318357a9bcc6fb087c92e27dcf435ff60796b 100644 --- a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp +++ b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -1,34 +1,13 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. 
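A hedged usage sketch for the DeviceUnaryElementwise operation defined earlier in this diff. The template parameter order shown here (ADataType, BDataType, ElementwiseFunctor, Dim, ScalarPerVector) is an assumption, as are the shapes and pointers; it is meant only to show how MakeArgumentPointer and the invoker fit together:

#include <vector>
#include "ck/tensor_operation/gpu/device/device_unary_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

using Square = ck::tensor_operation::element_wise::UnarySquare;

// Assumed parameter order: ADataType, BDataType, functor, Dim, ScalarPerVector.
using DeviceSquare =
    ck::tensor_operation::device::DeviceUnaryElementwise<float, float, Square, 2, 4>;

void run_square_sketch(const float* p_a_dev, float* p_b_dev)
{
    DeviceSquare op;

    std::vector<int> shape  = {1024, 1024}; // hypothetical 2D problem
    std::vector<int> stride = {1024, 1};    // packed, row-major

    auto arg     = op.MakeArgumentPointer(p_a_dev, p_b_dev, shape, stride, stride, Square{});
    auto invoker = op.MakeInvokerPointer();

    if(op.IsSupportedArgument(arg.get()))
        invoker->Run(arg.get(), StreamConfig{nullptr, true});
}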
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_OPERATOR_MAPPING_HPP -#define CK_REDUCTION_OPERATOR_MAPPING_HPP - -#include "reduction_operator.hpp" -#include "reduction_enums.hpp" -#include "element_wise_operation.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// FIXME: can it be replaced with ck::Tuple? +#include namespace ck { @@ -37,77 +16,69 @@ namespace ck { // The boolean member "indexable" are also provided in reduce_binary_operactor for // easier checking by the upper-layer codes in the kernels. -template +template struct reduce_binary_operator; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Add; - using dataType = T; + using opType = reduce::Add; static constexpr bool indexable = false; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Mul; - using dataType = T; + using opType = reduce::Mul; static constexpr bool indexable = false; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Min; - using dataType = T; + using opType = reduce::Min; static constexpr bool indexable = true; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Max; - using dataType = T; + using opType = reduce::Max; static constexpr bool indexable = true; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::AMax; - using dataType = T; + using opType = reduce::AMax; static constexpr bool indexable = true; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Add; - using dataType = T; + using opType = reduce::Add; static constexpr bool indexable = false; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Add; - using dataType = T; + using opType = reduce::Add; static constexpr bool indexable = false; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Add; - using dataType = T; + using opType = reduce::Add; static constexpr bool indexable = false; }; @@ -115,55 +86,101 @@ struct reduce_binary_operator // The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary // functor classes. 
// The two unary functors are called before and afer the Reduction is executed respectively -template +template struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength}); + }; }; -template -struct reduce_unary_operator +template +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template <> +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template <> +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; - using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct 
reduce_unary_operator +template <> +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; - using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -} // end of namespace ck - -#endif +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index 2409071b4828cc0a0dc08a5098fb39baaf596293..40c7eb7d5ec18c68d0256f917139a1dbf62a381f 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once namespace ck { diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 1032f0f8fc162bc481f60991d70c0a485b685a4c..ece1ecb865c8eeba06f654174dcd68dfdb07c480 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,127 +1,186 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2022 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
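To make the pre-/post-reduction pairing above concrete: for NORM2 the mapping selects UnarySquare as the operation applied to each element before the Add reduction and UnarySqrt as the operation applied once to the accumulated value. A purely host-side sketch of that composition with hypothetical inputs:

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

float host_norm2_sketch()
{
    ck::tensor_operation::element_wise::UnarySquare square{};
    ck::tensor_operation::element_wise::UnarySqrt   sqrt_op{};

    float acc = 0.0f;               // identity value of the Add reduction
    for(float x : {3.0f, 4.0f})     // hypothetical input elements
    {
        float t;
        square(t, x);               // pre-reduction op: t = x * x
        acc += t;                   // the Add reduction itself
    }

    float nrm2;
    sqrt_op(nrm2, acc);             // post-reduction op: sqrt(9 + 16) = 5
    return nrm2;
}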
+ #pragma once -#include "data_type.hpp" + +#include "ck/utility/data_type.hpp" namespace ck { namespace tensor_operation { -namespace binary_element_wise { - -template -struct Add; +namespace element_wise { -template <> -struct Add +struct Add { + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> __host__ __device__ constexpr void - operator()(double& dst, const double& src1, const double& src2) const + operator()(float& y, const float& x0, const float& x1) const { - dst = src1 + src2; - } -}; + y = x0 + x1; + }; -template <> -struct Add -{ + template <> __host__ __device__ constexpr void - operator()(float& dst, const float& src1, const float& src2) const + operator()(double& y, const double& x0, const double& x1) const { - dst = src1 + src2; - } -}; + y = x0 + x1; + }; -template <> -struct Add -{ + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 + x1; + }; + + template <> __host__ __device__ constexpr void - operator()(half_t& dst, const half_t& src1, const half_t& src2) const + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const { - dst = src1 + src2; + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = ck::type_convert(y_tmp); } }; -template <> -struct Add +struct Subtract { + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> __host__ __device__ constexpr void - operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const + operator()(float& y, const float& x0, const float& x1) const { - const float x1 = ck::type_convert(src1); - const float x2 = ck::type_convert(src2); - const float y = x1 + x2; - dst = ck::type_convert(y); - } -}; + y = x0 - x1; + }; -template -struct Substract; + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 - x1; + }; -template <> -struct Substract -{ + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 - x1; + }; + + template <> __host__ __device__ constexpr void - operator()(double& dst, const double& src1, const double& src2) const + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const { - dst = src1 - src2; + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp - x2_tmp; + y = ck::type_convert(y_tmp); } }; -template <> -struct Substract +struct Bilinear { + Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; + + template <> __host__ __device__ constexpr void - operator()(float& dst, const float& src1, const float& src2) const + operator()(float& y, const float& x0, const float& x1) const { - dst = src1 - src2; - } + y = alpha_ * x0 + beta_ * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); + }; + + float alpha_; + float beta_; }; -template <> -struct Substract +struct AddRelu { + template + __host__ __device__ 
constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> __host__ __device__ constexpr void - operator()(half_t& dst, const half_t& src1, const half_t& src2) const + operator()(float& y, const float& x0, const float& x1) const { - dst = src1 - src2; - } + const float a = x0 + x1; + y = a > 0.0f ? a : 0.0f; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = a > 0.0 ? a : 0.0; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = a > type_convert(0.0f) ? a : type_convert(0.0f); + }; }; -template <> -struct Substract +struct AddHardswish { + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> __host__ __device__ constexpr void - operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const + operator()(float& y, const float& x0, const float& x1) const { - const float x1 = ck::type_convert(src1); - const float x2 = ck::type_convert(src2); - const float y = x1 - x2; - dst = ck::type_convert(y); - } + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + double a = x0 + x1; + double b = a + 3.0; + double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + float a = x0 + x1; + float b = a + 3.0f; + float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f; + y = c; + }; }; -} // namespace binary_element_wise +} // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index d899bdc967f57b19f1b3e870ac8f96227da0404a..9c273e750b8cc2bf7b3e625efb0a194a97dcd030 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,109 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
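The binary functors above share one pattern worth calling out: the primary operator() template is only declared, never defined, so a call with a type combination that has no explicit specialization fails to build instead of silently converting its operands. A host-side sketch using Bilinear with hypothetical scaling factors:

#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"

void bilinear_sketch()
{
    ck::tensor_operation::element_wise::Bilinear op{2.0f, 0.5f}; // alpha = 2, beta = 0.5

    float e = 0.0f;
    op(e, 3.0f, 4.0f);  // float specialization: e = 2 * 3 + 0.5 * 4 = 8
    // op(e, 3.0, 4.0); // double operands have no specialization, so this would not build
}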
+ #pragma once -#include "data_type.hpp" -#include "math_v2.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" namespace ck { namespace tensor_operation { namespace element_wise { -struct PassThrough -{ - __host__ __device__ void operator()(float& y, const float& x) const { y = x; } - - __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; } - - __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x; } - - __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; } - - __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; } - - __host__ __device__ void operator()(double& y, const double& x) const { y = x; } -}; - -struct Add -{ - __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const - { - y = x0 + x1; - } - - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1) const - { - // FIXME - Use float (acc type) bias in the future. - y = x0 + x1; - } -}; - -struct AlphaBetaAdd -{ - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {} - - __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const - { - y = alpha_ * x0 + beta_ * x1; - } - - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1) const - { - // FIXME - Let x0 be acc type - y = static_cast(alpha_ * static_cast(x0) + beta_ * static_cast(x1)); - } - - float alpha_; - float beta_; -}; - -struct AddRelu -{ - __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const - { - const float a = x0 + x1; - y = a > 0 ? a : 0; - } - - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1) const - { - const half_t a = x0 + x1; - y = a > 0 ? a : 0; - } -}; - -struct AddHardswish -{ - __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const - { - float a = x0 + x1; - float b = a + float{3}; - float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; - y = c; - } - - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1) const - { - float a = x0 + x1; - float b = a + float{3}; - float c = (b > 0) * (b > float{6} ? 
float{6} : b) * a * float{0.166667}; - y = c; - } -}; +// Need to ensure compiler will fail if there is no matching candidate, instead of compiler +// siliently do implicit type conversion +// +// Method 1: +// +// struct ExampleElementwiseOp +// { +// template +// __host__ __device__ constexpr void +// operator()(Y&, const X) const; +// +// template<> +// __host__ __device__ constexpr void +// operator()(half_t& y, const half_t& x) const +// { +// } +// }; +// +// Method 2: +// +// template +// struct ExampleElementwiseOp; +// +// template <> +// struct ExampleElementwiseOp +// { +// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const +// { +// } +// }; struct AddReluAdd { - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const { half_t a = x0 + x1; half_t b = a > 0 ? a : 0; y = b + x2; } - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1, const float& x2) const + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const { float a = x0 + x1; float b = a > 0 ? a : 0; @@ -111,8 +69,9 @@ struct AddReluAdd y = c; } - __host__ __device__ constexpr void - operator()(half_t& y, const float& x0, const half_t& x1, const half_t& x2) const + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const float& x0, const half_t& x1, const half_t& x2) const { float a = x0 + x1; float b = a > 0 ? a : 0; @@ -123,8 +82,14 @@ struct AddReluAdd struct AddHardswishAdd { - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1, const float& x2) const + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const { float a = x0 + x1; float b = a + float{3}; @@ -133,8 +98,9 @@ struct AddHardswishAdd y = d; } - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const { float a = x0 + x1; float b = a + float{3}; @@ -144,206 +110,95 @@ struct AddHardswishAdd } }; -struct Normalize -{ - Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {} - - __host__ __device__ constexpr void operator()(float& y, - const float& x, - const float& mean, - const float& mean_square, - const float& gamma, - const float& beta) const - { - float variance = mean_square - (mean * mean); - y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta; - } - - float epsilon_; -}; - -// Unary operators are usually called element-wisely before/after the reduction is executed on the -// elements. 
They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 - -template -struct UnaryIdentic; - -template <> -struct UnaryIdentic +// C = A * B +// E = FastGelu(C + D0 + D1) +struct AddAddFastGelu { - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(float& y, const float& x) const { y = x; }; -}; + template + __host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const; -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; - - __host__ __device__ void operator()(float& y, const float& x) const + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const { - y = x / type_convert(divider_); - }; + // Fast GeLU + // https://paperswithcode.com/method/gelu + // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) + const auto fast_gelu = [&](float x) { + const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); + const float emu = exp(-u); + const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); + return x * cdf; + }; - int32_t divider_ = 1; -}; + const float y = fast_gelu(c + float(d0) + float(d1)); -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }; + e = type_convert(y); + } }; -template <> -struct UnaryIdentic +struct Normalize { - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; + // FIXME: is double absolutely necessary? + Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {} - __host__ __device__ void operator()(double& y, const double& x) const { y = x; }; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; + template + __host__ __device__ constexpr void operator()( + T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const; - __host__ __device__ void operator()(double& y, const double& x) const + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x, + const float& mean, + const float& mean_square, + const float& gamma, + const float& beta) const { - y = x / type_convert(divider_); - }; - - int32_t divider_ = 1; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; - - __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; }; - - int32_t divider_ = 1; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }; -}; - -template -struct UnarySquare; - -template <> -struct UnarySquare -{ - __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; }; + using ck::math::sqrt; - __host__ __device__ void operator()(float& y, const float& x) const { y = x * x; }; -}; - -template <> -struct UnarySquare -{ - __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; }; - - __host__ 
__device__ void operator()(float& y, const float& x) const - { - y = x * x / type_convert(divider_); + float variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + type_convert(epsilon_))) * gamma + beta; }; - int32_t divider_ = 1; -}; - -template <> -struct UnarySquare -{ - __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const { y = x * x; }; -}; - -template <> -struct UnarySquare -{ - __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const + template <> + __host__ __device__ constexpr void operator()(double& y, + const double& x, + const double& mean, + const double& mean_square, + const double& gamma, + const double& beta) const { - y = x * x / type_convert(divider_); - }; + using ck::math::sqrt; - int32_t divider_ = 1; -}; - -template -struct UnaryAbs; - -template <> -struct UnaryAbs -{ - __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); }; -}; - -template <> -struct UnaryAbs -{ - __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); }; -}; - -template <> -struct UnaryAbs -{ - __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); }; -}; - -template <> -struct UnaryAbs -{ - __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; + double variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta; + }; - __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); }; + // FIXME: is double absolutely necessary? 
+ double epsilon_; }; template -struct UnarySqrt; +struct UnaryTypeConvert; template <> -struct UnarySqrt +struct UnaryTypeConvert { - __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); }; + __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const + { + y = ck::type_convert(x); + } }; template <> -struct UnarySqrt +struct UnaryTypeConvert { - __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const + __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const { - y = ck::math::sqrt(x); - }; + y = ck::type_convert(x); + } }; } // namespace element_wise diff --git a/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp deleted file mode 100644 index 038e36f564d4bc87234a715e6b6073de5f51e9e6..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once -#include "data_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace element_wise { - -} // namespace element_wise -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..440f0de4d4e7a6cf6bfb63ac7160b8d0a6004017 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
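The Normalize functor defined just above applies the usual layer-norm formula, y = (x - mean) / sqrt(variance + epsilon) * gamma + beta, with the variance recovered as mean_square - mean * mean. A small worked example with hypothetical values:

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

void normalize_sketch()
{
    ck::tensor_operation::element_wise::Normalize norm{1e-4};

    float y = 0.0f;
    // x = 2, mean = 1, E[x^2] = 2   =>  variance = 2 - 1 * 1 = 1
    // y = (2 - 1) / sqrt(1 + 1e-4) * 1.5 + 0.25  ~=  1.7499
    norm(y, 2.0f, 1.0f, 2.0f, 1.5f, 0.25f);
}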
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct PassThrough +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = x; + }; +}; + +struct Scale +{ + __host__ __device__ Scale(float scale) : scale_(scale) {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = scale_ * x; + }; + + float scale_; +}; + +struct UnaryDivide +{ + __host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {} + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = x / type_convert(divider_); + }; + + int32_t divider_ = 1; +}; + +struct UnarySquare +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = x * x; + }; +}; + +struct UnaryAbs +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::abs(x); + }; +}; + +struct UnarySqrt +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sqrt(x); + }; +}; + +struct Relu +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + y = x > 0 ? x : 0; + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float x_f32 = ck::type_convert(x); + float y_f32 = x_f32 > 0 ? x_f32 : 0; + y = ck::type_convert(y_f32); + } +}; + +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) +struct FastGelu +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); + const float emu = exp(-u); + const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); + + y = x * cdf; + } +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 792060ca86299050d13e56e54d3e6d9906ece192..498a88afe0dcb546f1e11b1d8df06fa346aac1fe 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,10 +1,12 @@ -#ifndef UTILITY_BLOCK_TO_CTILE_MAP -#define UTILITY_BLOCK_TO_CTILE_MAP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
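FastGelu above evaluates GELU through a sigmoid-style approximation of the cdf rather than calling tanh directly; 0.797885 is sqrt(2/pi) and 0.035677 is 0.044715 * sqrt(2/pi), matching the formula quoted in its comment. A host-side sketch with a hypothetical input:

#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

void fast_gelu_sketch()
{
    ck::tensor_operation::element_wise::FastGelu gelu{};

    float y = 0.0f;
    gelu(y, 1.0f); // u ~= 1.667, cdf ~= 0.841, so y ~= 0.841 (reference GELU(1) ~= 0.8413)
}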
-#include "utility/math.hpp" -#include "utility/number.hpp" -#include "tensor_description/tensor_adaptor.hpp" -#include "tensor_description/multi_index_transform_helper.hpp" +#pragma once + +#include "ck/utility/math.hpp" +#include "ck/utility/number.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" namespace ck { @@ -485,5 +487,3 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx, } } // namespace ck - -#endif // UTILITY_BLOCK_TO_CTILE_MAP diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp index b2f06c03c689734475f4a29dc17972cb6a7fa373..6836a66047531d9eef09dbffc58188f7696a783a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -1,39 +1,15 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP -#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_accumulate.hpp" -#include "reduction_functions_blockwise.hpp" -#include "reduction_functions_threadwise.hpp" - -#include "threadwise_tensor_slice_transfer.hpp" -#include "element_wise_operation.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { @@ -171,15 +147,15 @@ struct GridwiseReduction_mk_to_m_multiblock AccDataType beta, OutDataType* const __restrict__ p_out_value_global) { - const auto identityVal = ReduceOperation::GetIdentityValue(); + const auto identityVal = ReduceOperation::template GetIdentityValue(); // LDS __shared__ AccDataType p_reduce_work_buffer[BlockSize]; - const auto in_global_val_buf = - make_dynamic_buffer(p_in_value_global, - in_grid_desc_m_k.GetElementSpaceSize(), - type_convert(identityVal)); + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); auto out_global_val_buf = make_dynamic_buffer( p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); @@ -358,12 +334,12 @@ struct GridwiseReduction_mk_to_m_multiblock __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; - const auto identityVal = ReduceOperation::GetIdentityValue(); + const auto identityVal = ReduceOperation::template GetIdentityValue(); - const auto in_global_val_buf = - make_dynamic_buffer(p_in_value_global, - in_grid_desc_m_k.GetElementSpaceSize(), - type_convert(identityVal)); + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); const auto in_global_idx_buf = make_dynamic_buffer( p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); auto out_global_val_buf = make_dynamic_buffer( @@ -635,4 +611,3 @@ struct GridwiseReduction_mk_to_m_multiblock }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp index 074aafb9d48acfcd94ccc0ec429a1dad4d68dfde..6c5bd29f9b5ee70b995342f04d59b8ed042b1194 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -1,38 +1,15 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_2D_REDUCTION_THREADWISE_HPP -#define CK_GRIDWISE_2D_REDUCTION_THREADWISE_HPP - -#include "data_type.hpp" -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_accumulate.hpp" -#include "reduction_functions_threadwise.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "element_wise_operation.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { @@ -135,12 +112,12 @@ struct GridwiseReduction_mk_to_m_threadwise ReduceOperation, PropagateNan>; - const auto identityVal = ReduceOperation::GetIdentityValue(); + const auto identityVal = ReduceOperation::template GetIdentityValue(); - const auto in_global_val_buf = - make_dynamic_buffer(p_in_value_global, - in_grid_desc_m_k.GetElementSpaceSize(), - type_convert(identityVal)); + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); auto dst_global_buf = make_dynamic_buffer( p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); @@ -276,12 +253,12 @@ struct GridwiseReduction_mk_to_m_threadwise (void)acc_elementwise_op; - const auto identityVal = ReduceOperation::GetIdentityValue(); + const auto identityVal = ReduceOperation::template GetIdentityValue(); - const auto in_global_val_buf = - make_dynamic_buffer(p_in_value_global, - in_grid_desc_m_k.GetElementSpaceSize(), - type_convert(identityVal)); + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); const auto in_global_idx_buf = make_dynamic_buffer( p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); @@ -495,4 +472,3 @@ struct GridwiseReduction_mk_to_m_threadwise }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp index d3342b072e077b0e3a93f1cad1a34ede1b833dd7..2393734826a6b93db72cd0a0657118f24a4e9eb1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp @@ -1,9 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
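Both gridwise reductions above now obtain their identity element through ReduceOperation::GetIdentityValue and hand it to make_dynamic_buffer as the value returned for out-of-range reads of the padded input, so padding cannot perturb the result. A host-side sketch of why that works, assuming reduce::Add from ck/utility/reduction_operator.hpp exposes GetIdentityValue as used above:

#include "ck/utility/reduction_operator.hpp"

float add_identity_sketch()
{
    using Add = ck::reduce::Add;

    float acc = Add::GetIdentityValue<float>(); // 0.0f, the neutral element of addition

    // The last "element" stands in for an out-of-range read that returns the identity.
    for(float x : {1.0f, 2.0f, 3.0f, Add::GetIdentityValue<float>()})
        acc += x;

    return acc; // 6.0f, unchanged by the identity-valued read
}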
+ #pragma once -#include "cluster_descriptor.hpp" -#include "data_type.hpp" -#include "element_wise_operation.hpp" -#include "threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp index 374c4fe59a0e5013220a6c930db44214fac42e7f..d4e7d1421da17d80bc57ed155796deef793dd00d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp @@ -1,9 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "cluster_descriptor.hpp" -#include "data_type.hpp" -#include "element_wise_operation.hpp" -#include "threadwise_tensor_slice_transfer.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp index a9b6d8dfa0d955dfb7f05eb9cc0cda2154e0ce86..2369f51795d9a3bcf5a28ad9de702d75e24fac48 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP #define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..22d96a10a2a38c49113a202099757b2afddf28fa --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,995 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
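For orientation, a rough host-side reference of what the fused kernel defined below appears to compute per output tile: a GEMM, a bias add followed by an activation, a second elementwise addend, and per-row reductions. This is an illustrative reconstruction from the Run() body (the A/B elementwise ops and the exact bias broadcast are governed by the grid descriptors and are omitted here); it is not CK code:

#include <cstddef>
#include <functional>
#include <vector>

// Assumed semantics, names illustrative only:
//   e    = c_op(A*B + c0)                     // "c = activation(c + bias)"
//   e   += c1_op(c1)                          // "c = c + c1_op(c1)"
//   r[m] = reduce over n of in_op(e[m][n])    // one such output per entry of p_reduces_grid
void reference_gemm_bias_add_reduce(std::size_t M, std::size_t N, std::size_t K,
                                    const std::vector<float>& A,  // M x K
                                    const std::vector<float>& B,  // K x N
                                    const std::vector<float>& C0, // M x N view of the bias
                                    const std::vector<float>& C1, // M x N
                                    std::vector<float>& E,        // M x N
                                    std::vector<float>& R,        // M, sum-reduction shown
                                    const std::function<float(float)>& c_op,
                                    const std::function<float(float)>& c1_op,
                                    const std::function<float(float)>& in_op)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        float r = 0.0f; // identity value of the assumed Add reduction
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];

            float e = c_op(acc + C0[m * N + n]); // bias add + activation
            e += c1_op(C1[m * N + n]);           // second elementwise input
            E[m * N + n] = e;
            r += in_op(e); // feeds the reduction written to the reduce grid
        }
        R[m] = r;
    }
}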
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const C1ElementwiseOperation c1_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_bias_grid, + p_d0_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_bias_grid; + ignore = p_d0_grid; + ignore = p_reduces_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c1_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct 
GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using ReduceGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const C1ElementwiseOperation& c1_element_op, + const ReduceInElementwiseOperations& reduce_in_element_ops, + const ReduceAccElementwiseOperations& reduce_out_element_ops, + const 
AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + 
make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + 
SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = 
reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + ReduceGlobalMemoryDataOperation::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + // c0 and c1 + constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock; + + auto c01_thread_buf = make_static_buffer( + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatReduceAcc, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC1, + FloatReduceAcc, + decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatC, + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + 
block_sync_lds(); + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_buf, + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = activation(c + bias) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); + c_reduce_thread_buf(i) = out; + }); + + c1_thread_copy_global_to_vgpr.Run( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_buf, + c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = c + c1_functior(c1) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c1_element_op(c01_thread_buf(i), c01_thread_buf(i)); + c_reduce_thread_buf(i) += c01_thread_buf(i); + }); + + c_reduce_thread_copy_vgpr_to_global.Run( + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c_reduce_thread_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto reduce_identityVal = + ReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C1 + c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } // 
Reduction + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 3b5daf6eadc469388851b0e04bea024dbd361277..ed98b6266f41c7fc5a6e3da4cba8843f39ad8c5c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -1,15 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_dl_v2r3.hpp" -#include "blockwise_tensor_slice_transfer_v5r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" -#include "element_wise_operation.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp index a7ff81e209479498caa3eabb4d8e8f9ad6ff2760..84e033e1e91bb9873ebed66906c6da5a287bf1bd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP #define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp index 607a05d1561f538c4e15a6c25d4059437b2623c9..b1dfb0c73fc92a19662e01df31090008cdef87ff 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_GRIDWISE_GEMM_V2_HPP #define CK_GRIDWISE_GEMM_V2_HPP diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp index a36b5e53ce0e07de1006f8cb46a07dead56b3a2a..ace844338414668030e5dd6aeeaf559e1a7c1060 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
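The (MBlock, MPerBlock, NBlock, NPerBlock) descriptors built by MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock and MakeReduceGridDescriptor_MBlock_MPerBlock above are plain unmerges of the row and column indices. A small sketch of the index arithmetic those unmerge transforms encode; illustrative only, not CK code:

#include <cstdio>

// Split a flat (m, n) element index into the block-tile coordinates used by the
// unmerge transforms: m -> (mblock, mperblock), n -> (nblock, nperblock).
struct TileCoord
{
    int mblock, mperblock, nblock, nperblock;
};

TileCoord split(int m, int n, int MPerBlock, int NPerBlock)
{
    return {m / MPerBlock, m % MPerBlock, n / NPerBlock, n % NPerBlock};
}

int main()
{
    // With 256 x 128 block tiles, element (m, n) = (300, 70) lives in block tile (1, 0).
    TileCoord c = split(300, 70, 256, 128);
    std::printf("(mblock, mperblock, nblock, nperblock) = (%d, %d, %d, %d)\n",
                c.mblock, c.mperblock, c.nblock, c.nperblock); // prints (1, 44, 0, 70)
}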
+ #ifndef CK_GRIDWISE_GEMM_V3_HPP #define CK_GRIDWISE_GEMM_V3_HPP diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5ce7db0a9770696bb62aba41f064d0a404d61811 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,675 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// GEMM: +// input : A[AK0, M, AK1] +// input : B[AK0, N, AK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + 
Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using 
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const StaticallyIndexedArray& + ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different + // type from E + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + 
BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = 
concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(FloatCShuffle{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + 
cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 20c3a0b618517576e456598a10acd85d9b7269bb..42a56e2a6b741065e1eda6e89081b944cf6cfcc0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3281b910d3d8782fc18f91ee77ce42aa44fc5ba9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +struct GridwiseGemmPipeline_v2 +{ + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // move to 1 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write 0 + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // global Read 1 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds(); + + // GEMM i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // move to i + 2 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // LDS write i + 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global read i + 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write i + 1 + b_blockwise_copy.RunWrite(b_block_desc, 
b_block_buf); + // global read i + 2 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + ++i; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // LDS write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + block_sync_lds(); + + // GEMM num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index c178e2949633e46114da9fab88bd419eaeccde44..8e29b5189ad87db48b1acdc54aa3b6c23439aac7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -1,31 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "thread_group_tensor_slice_transfer_v6r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_v1.hpp" -#include "reduction_functions_threadwise.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { template __global__ void @@ -36,17 +41,17 @@ __global__ void const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - DPtrsGlobal p_ds_grid, + ReducePtrsGlobal p_reduces_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, - const DxsInElementwiseOperation dxs_in_element_op, - const DxsAccElementwiseOperation dxs_out_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) 
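
Editorial aside, not part of the patch: the prefetch schedule implemented by GridwiseGemmPipeline_v2 above is easier to follow when written out for a small, even iteration count (IsSupported() requires num_loop % 2 == 0). The host-side sketch below only prints the order of global reads, LDS writes, and GEMMs implied by the loop bounds in Run(); the printf-based driver and the chosen num_loop are illustrative assumptions, not library code.

#include <cstdio>

int main()
{
    const int num_loop = 4; // GridwiseGemmPipeline_v2::IsSupported() requires an even count

    // prologue: tile 0 is read from global memory and staged in LDS while tile 1 is fetched
    std::printf("global read 0\n");
    std::printf("LDS write 0 | global read 1\n");

    // main body (taken when (num_loop / 2) > 1): iteration i runs the GEMM on tile i,
    // then stages tile i+1 into LDS and fetches tile i+2 from global memory
    for(int i = 0; i < num_loop - 2; ++i)
    {
        std::printf("GEMM %d | LDS write %d | global read %d\n", i, i + 1, i + 2);
    }

    // tail: the last two GEMMs consume data that has already been fetched
    std::printf("GEMM %d\n", num_loop - 2);
    std::printf("LDS write %d\n", num_loop - 1);
    std::printf("GEMM %d\n", num_loop - 1);

    return 0;
}
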
@@ -55,32 +60,32 @@ __global__ void GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, - p_ds_grid, + p_reduces_grid, p_shared, a_element_op, b_element_op, c_element_op, - dxs_in_element_op, - dxs_out_element_op, + reduce_in_element_ops, + reduce_out_element_ops, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, - d_grid_desc_mblock_mperblock, + reduce_grid_desc_mblock_mperblock, block_2_ctile_map); #else ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = p_ds_grid; + ignore = p_reduces_grid; ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; - ignore = dxs_in_element_op; - ignore = dxs_out_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; ignore = a_grid_desc_ak0_m_ak1; ignore = b_grid_desc_bk0_n_bk1; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = d_grid_desc_mblock_mperblock; + ignore = reduce_grid_desc_mblock_mperblock; ignore = block_2_ctile_map; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -90,19 +95,19 @@ template {}))), make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - return d_grid_desc_mblock_mperblock; + return reduce_grid_desc_mblock_mperblock; } // return block_id to C matrix tile idx (m0, n0) mapping @@ -313,29 +318,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using DGridDescriptor_MBlock_MPerBlock = - remove_cvref_t; + using ReduceGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - DPtrsGlobal p_ds_grid, - void* __restrict__ p_shared, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const DxsInElementwiseOperation& dxs_in_element_op, - const DxsAccElementwiseOperation& dxs_out_element_op, - const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, - const Block2CTileMap& block_2_ctile_map) + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const ReduceInElementwiseOperations& reduce_in_element_ops, + const ReduceAccElementwiseOperations& reduce_out_element_ops, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock, + const Block2CTileMap& block_2_ctile_map) { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -701,12 +707,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - // VGPR d_reduce_thread_desc_mperblock - constexpr auto d_reduce_thread_desc_mperblock = + 
// VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = make_naive_tensor_descriptor_packed(make_tuple(Number{})); - // VGPR d_reduce_thread_desc_mblock_mperblock - constexpr auto d_reduce_thread_desc_mblock_mperblock = + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); auto c_reduce_thread_buf = make_static_buffer( @@ -735,29 +741,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 1, true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; - auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple( + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( [&](auto I) { - auto p_d_grid = p_ds_grid[I]; - auto d_out_element_op = dxs_out_element_op[I]; + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = reduce_out_element_ops[I]; return ThreadwiseTensorSliceTransfer_v1r3< FloatReduceAcc, - remove_pointer_t, - decltype(d_reduce_thread_desc_mblock_mperblock), - decltype(d_grid_desc_mblock_mperblock), - decltype(d_out_element_op), + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), Sequence<1, mreduce_per_thread>, Sequence<0, 1>, 1, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, - DGlobalMemoryDataOperation::At(I), + ReduceGlobalMemoryDataOperation::At(I), 1, - false>{d_grid_desc_mblock_mperblock, + false>{reduce_grid_desc_mblock_mperblock, make_multi_index(block_work_idx[I0], // mblock c_reduce_thread_data_idx_begin[I0]), // mperblock - d_out_element_op}; + reduce_acc_element_op}; }, - Number{}); + Number{}); constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); @@ -792,34 +798,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_tuple(I0, I0), c_reduce_thread_buf); - static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { - auto& p_d_grid = p_ds_grid[In]; + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; - auto d_grid_buf = make_dynamic_buffer( - p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); - auto d_thread_buf = + auto reduce_thread_buf = make_static_buffer( - d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + reduce_thread_desc_mperblock.GetElementSpaceSize()); - auto& d_in_element_op = dxs_in_element_op[In]; + auto& reduce_in_element_op = reduce_in_element_ops[In]; - auto& d_reduce_thread_copy_vgpr_to_global = - dxs_reduce_thread_copy_vgpr_to_global(In); + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); - using DReduceOperation = remove_cvref_t; + using ReduceOperation = remove_cvref_t; using ThreadwiseReduce = ThreadwiseReduction; // Global write Gemm shuffle + reduction - const auto d_identityVal = DReduceOperation::GetIdentityValue(); + const auto reduce_identityVal = + ReduceOperation::template GetIdentityValue(); static_for<0, mreduce_per_thread, 1>{}( - [&](auto I) { d_thread_buf(I) = d_identityVal; }); + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); // reduce in VGPR static_for<0, mreduce_per_thread, 1>{}([&](auto im) { @@ -828,26 +835,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 Number{}; - d_in_element_op(c_reduce_thread_buf(offset), - c_reduce_thread_buf(offset)); + 
reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); }); }); - ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf); + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); // copy from VGPR to Global - d_reduce_thread_copy_vgpr_to_global.Run( - d_reduce_thread_desc_mblock_mperblock, - make_tuple(I0, I0), - d_thread_buf, - d_grid_desc_mblock_mperblock, - d_grid_buf); + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); if constexpr(access_id < num_access - 1) { constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( - d_grid_desc_mblock_mperblock, + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, make_tuple(c_global_step[I0], c_global_step[I1])); } }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 55390dbc864e24bc1f8411dee3897e36aff5c900..3fa6c10e099cac9bb20ce9a46da6831209aeccf1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -1,14 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "thread_group_tensor_slice_transfer_v6r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_v1.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { @@ -129,7 +135,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = +#if 1 + remove_cvref_t())>; +#else + GridwiseGemmPipeline_v2; +#endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -420,8 +433,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); // gridwise GEMM pipeline - const auto gridwise_gemm_pipeline = - GridwiseGemmPipeline_v1_Selector(); + static_assert(std::is_default_constructible_v); + const auto 
gridwise_gemm_pipeline = GridwiseGemmPipe{}; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1b8286cfc4c1a8dad912a8335896975229e8e3f8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -0,0 +1,1066 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_layernorm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, // MxN + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + // TODO ANT: separate into MMA + Epilogue + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_c0_bias_grid, + p_c0_add_grid, + p_c0_gamma_grid, + p_c0_beta_grid, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + block_2_ctile_map); + + // TODO ANT: Run layernorm epilogue here +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_bias_grid; + ignore = p_c0_add_grid; + ignore = p_c0_gamma_grid; + ignore = 
p_c0_beta_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_nblock_nperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +template +struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + // Align 16 bytes (maximum LDS read/write width) + constexpr auto c_block_size_aligned = + math::integer_least_multiple( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * + sizeof(FloatCShuffle), + 16) / + sizeof(FloatCShuffle); + + // LDS allocation for 
reduction workspace + constexpr index_t c_lds_workspace_size = BlockSize; + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size_aligned * sizeof(FloatCShuffle) + + c_lds_workspace_size * sizeof(FloatReduceAcc)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // in order to reduce N dim without elaborate sync across CUs in single kernel, one + // workgroup must span the entire N extent + if(math::integer_divide_ceil(N, NPerBlock) > 1) + { + return false; + } + + // static check: all waves in the workgroups combined must cover whole N extent in order + // to have efficient N-dim reduction + static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave, + "condition not met for efficient layernorm"); + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // for bias, beta, gamma + __host__ __device__ static constexpr auto + MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n) + { + const auto N = c0_grid_desc_n.GetLength(I0); + const auto NBlock = N / NPerBlock; + + const auto c0_grid_desc_nblock_nperblock = transform_tensor_descriptor( + c0_grid_desc_n, + make_tuple(make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return c0_grid_desc_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return 
BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C0GridDescriptor_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_bias_grid_buf = make_dynamic_buffer( + p_c0_bias_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here + auto c0_add_grid_buf = make_dynamic_buffer( + p_c0_add_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_gamma_grid_buf = make_dynamic_buffer( + p_c0_gamma_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + auto c0_beta_grid_buf = make_dynamic_buffer( + p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + 
AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0); + + // for broadcasting bias, beta, gamma + const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c0_grid_desc_nblock_nperblock, + make_tuple(make_insert_transform(I1), + make_insert_transform(I1), + make_pass_through_transform(NBlock), + make_pass_through_transform(NPerBlock)), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // pytorch default + // https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html + static constexpr FloatReduceAcc epsilon = 1e-5; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + 
make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + // VGPR d_reduce_thread_desc_mperblock + constexpr auto d_reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // TODO: this should be implemented as a blockwise reduction + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // Align 16 bytes (maximum LDS read/write width) + constexpr auto c_block_size_aligned = + math::integer_least_multiple( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * + sizeof(FloatCShuffle), + 16) / + sizeof(FloatCShuffle); + + auto d_reduce_work_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + c_block_size_aligned), + BlockSize); + + // Sum thread workspace + auto d0_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // Squared sum thread workspace + auto d1_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto c_reduce_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatCShuffle, + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_block_desc_mperblock_nperblock), + tensor_operation::element_wise::PassThrough, + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, + c_reduce_thread_data_idx_begin, + tensor_operation::element_wise::PassThrough{}}; + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatC0, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>(c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], + c_reduce_thread_data_idx_begin[I0], + block_work_idx[I1], + c_reduce_thread_data_idx_begin[I1])); + + // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here + auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatC0, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + 
CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>(c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], + c_reduce_thread_data_idx_begin[I0], + block_work_idx[I1], + c_reduce_thread_data_idx_begin[I1])); + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + block_sync_lds(); + + // load from LDS and global, add bias + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_bias_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + acc_element_op(out, + c_reduce_thread_buf(i) + + static_cast(c0_thread_buf(i))); + c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias) + }); + + c0_add_thread_copy_global_to_vgpr.Run( + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_add_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) += + static_cast(c0_thread_buf(i)); // add + }); + + // layernorm + { + using ThreadwiseReduceD0 = + ThreadwiseReduction; + using ThreadwiseReduceD1 = + ThreadwiseReduction; + + const auto d0_zeroVal = + ThreadwiseReduceD0::Op::template GetIdentityValue(); + const auto d1_zeroVal = + ThreadwiseReduceD1::Op::template GetIdentityValue(); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto i) { d0_thread_buf(i) = d0_zeroVal; }); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto i) { d1_thread_buf(i) = d1_zeroVal; }); + + // reduce sum in VGPR + ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf); + + // reduce squared sum in VGPR + ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf); + + // reduce within workgroup + using BlockwiseReduce = PartitionedBlockwiseReduction< + FloatReduceAcc, + BlockSize, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K + Sequence<1, 0>, // ThreadClusterArrangeOrder + reduce::Add, + false>; + + static_for<0, mreduce_per_thread, 1>{}([&](auto i) { + block_sync_lds(); + BlockwiseReduce::Reduce(d_reduce_work_buf, + d0_thread_buf(i)); // blockwise reduced sum + block_sync_lds(); + BlockwiseReduce::Reduce(d_reduce_work_buf, + d1_thread_buf(i)); // blockwise reduced squared sum + }); + + // 
normalize + const index_t NRaw = + c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0] + .GetUpperLengths()[I1]; // TODO: proper handle + + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto dst_offset = + Number{}; + + constexpr auto src_offset = + Number{}; + + FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw; + FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw; + + FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; + FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; + FloatReduceAcc divisor_sqrt; + tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor); + + c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; + }); + }); + + // scaling + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_gamma_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) *= + static_cast(c0_thread_buf(i)); // * gamma + }); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_beta_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) += + static_cast(c0_thread_buf(i)); // + beta + }); + + block_sync_lds(); + + c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf, + c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf); + + } // end layernorm + + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0_add + c0_add_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 0d3f8ddefb2fd47831edd1b81b0d7b744106c32a..3bb3774afa8a8feab5c8b80b261b0a66281a6e02 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -1,15 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "thread_group_tensor_slice_transfer_v6r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_v1.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { @@ -791,8 +795,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + void* p_shared = static_cast(p_shared_block); + auto c_block_buf = make_dynamic_buffer( - static_cast(p_shared_block), + static_cast(p_shared), c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); static_assert(M1 == MWave, ""); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 974455fa3b777edc34d93863c9b6df24518084c1..847bfd47cf78ebdc442dbd53ea2f7e5c78b6f100 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,14 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_v1.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index a54906cfbc5213eaf9c9eed43feb0fb5cdcf9332..949d56483669946f12f40adca568d4e2a6b0bb5b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -1,14 +1,18 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { @@ -607,7 +611,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 c_grid_buf); } } -}; // namespace ck +}; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index dbff1577e1f248d49b63348da6b3c5368b4a1eae..84e1af0a356856ca902be9beb465358f136398ea 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -1,15 +1,19 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4R2_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R4R2_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "thread_group_tensor_slice_transfer_v6r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { @@ -717,7 +721,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 }); } } -}; // namespace ck +}; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 2828655f51280c44f210593b6a7488fd57bfa702..71bf05ce21d5fde3c919c1c5b9beb3a937351611 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -1,15 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "thread_group_tensor_slice_transfer_v6r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_v1.hpp" -#include "tensor_space_filling_curve.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index 3a7a551181b8cdd50efa37625349a4278a0c1839..35f3bdeff70c0522d54505ddbfd645cb2f4593b7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -1,16 +1,19 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R2_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V3R2_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "thread_group_tensor_slice_transfer_v6r2.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_v1.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { @@ -755,4 +758,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 2e324faf1330d129d4c5f001e4e47631368b64f2..4e4ab9c9e83fcbe84d49f70f531faa64db8f1399 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -1,14 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "thread_group_tensor_slice_transfer_v6r3.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_v1.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp index 6d95aec93849fa8b89e0932d81b96fdb0c4855bc..1e52b4057c99d61b0e2093b69c934737ffef7c46 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -1,32 +1,9 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_SET_BUFFER_VALUE_HPP -#define CK_GRIDWISE_SET_BUFFER_VALUE_HPP - -#include "threadwise_tensor_slice_transfer.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { @@ -37,7 +14,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe { - using PassThroughOp = tensor_operation::element_wise::UnaryIdentic; + using PassThroughOp = tensor_operation::element_wise::PassThrough; constexpr auto I0 = Number<0>{}; @@ -77,4 +54,3 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..98b29ff82e002572b7d0e87cc281838c79c13060 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -0,0 +1,404 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, + const GridDesc_M_K out_grid_desc_m_k, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) +{ + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m_k, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); +}; + +template +struct GridwiseSoftmax_mk_to_mk +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (KThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k, + const GridDesc_M_K& out_grid_desc_m_k, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer + out_thread_buf; + + StaticBuffer max_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + max_value_buf(I) = reduce::Max::template GetIdentityValue(); + }); + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = reduce::Add::template GetIdentityValue(); + }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = 
block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // Normally, 0 as invalid element value is adequate since 0 makes no contribution to + // accumulated result. However, in stable softmax, all values 0s or not are subtracted by + // another value_max. As numbers become non-zero, effectively it allows invalid values to + // slip through and contribute to the accumulated result. + // + // The trick here is leveraging the fact that many math functions (add, sub, exp, ...) + // propagate NaNs when operands have NaNs involved. By initialiing invalid element value + // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still + // be identified as an invalid value. We can then discard the invalid values which + // originally failed the bound check during accumulation. This allows to ignore values that + // failed bound check even after multiple math manipulations. + // + // NOTE: reset coordinate after every step because the same threadwise copy will sweep + // through global memory 3 times back and forth + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2( + out_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3( + out_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + constexpr auto in_thread_copy_fwd_step = + make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize); + constexpr auto in_thread_copy_bwd_step = + make_multi_index(0, SweepOnce ? 
0 : -K_BlockTileSize); + + /// + /// max(x) + /// + using BlockwiseMaxReduce = PartitionedBlockwiseReduction< + AccDataType, + BlockSize, + ThreadClusterLengths_M_K, + ThreadClusterArrangeOrder, + reduce::Max, + false, // param ignored + detail::AccumulateWithNanIgnore>; + + using ThreadwiseMaxReduce = + ThreadwiseReduction>; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize()); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + static_for<0, MThreadSliceSize, 1>{}( + [&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); + + /// + /// sum(exp(x - max(x))) + /// + using BlockwiseSumReduce = PartitionedBlockwiseReduction< + AccDataType, + BlockSize, + ThreadClusterLengths_M_K, + ThreadClusterArrangeOrder, + reduce::Add, + false, // ignored + detail::AccumulateWithNanIgnore>; + + using ThreadwiseSumReduce = + ThreadwiseReduction>; + + reducedTiles = 0; + do + { + if constexpr(!SweepOnce) + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + } + + // do element-wise pre-reduction operation + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + math::exp(in_thread_buf(Number{}) - max_value_buf(iM)); + }); + }); + + ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I)); + // block_sync_lds(); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + + /// + /// softmax + /// + reducedTiles = 0; + if(float_equal_zero{}(beta)) + { + do + { + if constexpr(!SweepOnce) + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM); + }); + }); + + threadwise_dst_store.Run(thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf, + out_grid_desc_m_k, + out_global_val_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + StaticBuffer + in_prior_dst_buf; + do + { + if constexpr(!SweepOnce) + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + 
make_tuple(I0, I0), + in_thread_buf); + } + threadwise_dst_load.Run(out_grid_desc_m_k, + out_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_prior_dst_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM) + + beta * in_prior_dst_buf(Number{}); + }); + }); + + threadwise_dst_store.Run(thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf, + out_grid_desc_m_k, + out_global_val_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e7fbbc6c6fa86753667ad6720c8c762f4604838 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global, + BDataType* __restrict__ p_b_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const ElementwiseFunctor functor) +{ + GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor); +} + +template +struct GridwiseUnaryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m0 = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * ScalarPerVector); + } + + __host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0) + { + return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size) + { + const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector); + + return grid_size; + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + BDataType* __restrict__ p_b_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m0.GetElementSpaceSize()); + auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m0.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + 
StaticBuffer b_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + ScalarPerVector, + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m0, thread_store_global_offset}; + + auto b_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + ScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + b_grid_desc_m0, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto m0 = b_grid_desc_m0.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = m0 / (loop_step); + do + { + // read and process ScalarPerVector elements + a_global_load.Run( + a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); + + static_for<0, ScalarPerVector, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); + functor(b_thread_buf(Number{}), a_thread_buf(Number{})); + }); + + b_global_write.Run(thread_desc_m0, + make_tuple(I0), // SrcSliceOriginIdx + b_thread_buf, + b_grid_desc_m0, + b_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index); + b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp index 3dcfe3a030969eda4897663c8d3712ea71cfa8f7..188c62d93b0d60317ac139115b79f496612d2ea6 100644 --- a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -1,32 +1,9 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP -#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP - -#include "reduction_functions_accumulate.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_functions_accumulate.hpp" namespace ck { @@ -39,7 +16,9 @@ template + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithNanCheck> struct ThreadwiseReduction { static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; @@ -51,7 +30,7 @@ struct ThreadwiseReduction static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); - using Accumulation = detail::AccumulateWithNanCheck; + using Op = OpReduce; template __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf) @@ -73,12 +52,15 @@ struct ThreadwiseReduction // 2) DstDesc is known at compile-time // 3) SrcBuffer is static buffer // 4) DstBuffer is static buffer -template +template < + typename AccDataType, + typename IndexDataType, + typename SrcThreadDesc_M_K, + typename DstThreadDesc_M, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> struct ThreadwiseReductionWithIndex { static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; @@ -90,9 +72,6 @@ struct ThreadwiseReductionWithIndex static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); - using Accumulation = - detail::AccumulateWithIndexAndNanCheck; - template ::type = false> struct ThreadwiseTensorSliceTransfer_v2 { + static_assert((InvalidElementAsNaN && !std::is_integral::value) || + (!InvalidElementAsNaN), + "Filling invalid element as NaN is only for floating point types"); + static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; @@ -316,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + i * src_scalar_step_in_vector); - dst_buf(Number{}) = - type_convert(src_vector.template AsType()[i]); + if constexpr(InvalidElementAsNaN) + { + dst_buf(Number{}) = + is_src_valid + ? type_convert(src_vector.template AsType()[i]) + : NumericLimits::QuietNaN(); + } + else + { + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); + } }); if constexpr(idx_1d.value != num_access - 1) @@ -1168,4 +1185,3 @@ struct ThreadwiseTensorSliceTransfer_v4 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 4cd41ddb30dba449d7a3fb548136907e588e88fd..005f35e909611c5e6ae0a5393a0cfe38b8a37948 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -1,10 +1,12 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "static_tensor.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor/static_tensor.hpp" namespace ck { @@ -789,4 +791,3 @@ struct ThreadwiseTensorSliceTransfer_v3r1 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp index 1447f06f0223bb2bd2703d3e549d95815f28e622..6a73466efa443af294f2062ee5173cf548880c3d 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP #define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp index 2504c9285670e7c2013187169a68797958a3db87..6e8a23930bbb91677ee18bab216af6c45de72e4c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -1,9 +1,11 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" namespace ck { // Assume: @@ -171,4 +173,3 @@ struct ThreadwiseTensorSliceTransfer_v4r1 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index f0e9c7e76143131bf207e530a744d573d0b8b2fa..f13da341f9b2d36d992b1cd1835b43c89dcabd1f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,8 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index 042bc95f55e2a353ea6a876bf197387c8dc6280e..9c91cd9ca8f86df37fee6dfdf597be1cefe2d683 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -1,10 +1,12 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_space_filling_curve.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" namespace ck { @@ -206,7 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v6r1 SrcCoord src_coord_; DstCoord dst_coord_; const ElementwiseOperation element_op_; -}; // namespace ck +}; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp index ae85ba91e5841532c5254c97ce39b89e6038f1d3..68bc2726f4be6fa1b08fe9213e0db9440526f975 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -1,10 +1,12 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_space_filling_curve.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" namespace ck { @@ -256,4 +258,3 @@ struct ThreadwiseTensorSliceTransfer_v6r2 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp index 47024d5e688d5ec8eeab3c9e6600c270529c4877..0f5fb88b04540dbf355fb0f02998303b7680c193 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -1,10 +1,12 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_space_filling_curve.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" namespace ck { @@ -306,4 +308,3 @@ struct ThreadwiseTensorSliceTransfer_v6r3 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2eb1b0ee90a446ee3b14354aff018275cf7809f4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags> // Sequence +struct ThreadwiseTensorSliceTransfer_v7 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = + SpaceFillingCurve>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + auto generate_vectors = [&](auto data_types) { + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + }; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(SrcDatas{}); + auto dst_vectors = generate_vectors(DstDatas{}); + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), + is_src_valid); + }); + + // apply pointwise function + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + // get reference to src data + const 
auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index a39b795818ed5470580b71b6618ec3e35b8e2402..eaf0f1327511779b4639cbd3cb886c2b890e6f96 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,9 +1,11 @@ -#ifndef CK_XDLOPS_GEMM_HPP -#define CK_XDLOPS_GEMM_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "math.hpp" -#include "amd_xdlops.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_xdlops.hpp" namespace ck { @@ -786,4 +788,3 @@ struct XdlopsGemm }; } // namespace ck -#endif diff --git a/include/ck/utility/amd_address_space.hpp b/include/ck/utility/amd_address_space.hpp index 3c5939aaf307c60ec006897b3f1a6b132faa5ff3..9f1525914cdee16ee6cbed51297f0008d46f6085 100644 --- a/include/ck/utility/amd_address_space.hpp +++ b/include/ck/utility/amd_address_space.hpp @@ -1,7 +1,9 @@ -#ifndef CK_AMD_ADDRESS_SPACE_HPP -#define CK_AMD_ADDRESS_SPACE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "config.hpp" +#pragma once + +#include "ck/ck.hpp" #include "c_style_pointer_cast.hpp" // Address Space for AMDGCN @@ -41,4 +43,3 @@ __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_addres } } // namespace ck -#endif diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index cac11e0ca40e5a0daa2355adea1d8fd679f56585..7f0cc35bb32acfe36bfb457b92a1fb3173717820 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once #include "data_type.hpp" @@ -7,6 +10,8 @@ namespace ck { template union BufferResource { + __device__ constexpr BufferResource() : content{} {} + // 128 bit SGPRs to supply buffer resource in buffer instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions int32x4_t content; diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index f4af03cb75f127b0b51102e06f4738de5beca99f..b5a991e9431a6c342cc72e6a93b4f8dacba3b357 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP diff --git a/include/ck/utility/amd_llvm_intrinsic.hpp b/include/ck/utility/amd_llvm_intrinsic.hpp index e2ebfa6be87e601b403559fe571ebcc96fed9ad8..f84f86ae9634fea256b639b8f2ae00815c205d69 100644 --- a/include/ck/utility/amd_llvm_intrinsic.hpp +++ b/include/ck/utility/amd_llvm_intrinsic.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_AMD_LLVM_INTRINSIC_HPP #define CK_AMD_LLVM_INTRINSIC_HPP diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 7269412208b566885a3c34d8791a591fe6b14387..07b14338e8b879ff3b3e2104b292028d4f21b7fe 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_AMD_XDLOPS_HPP #define CK_AMD_XDLOPS_HPP diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 4c9dfd9a934894d6ef230383aaf37f2d9ac1774f..370a457fe9d97a621e6f91e78f14733398ebf756 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP diff --git a/include/ck/utility/array_multi_index.hpp b/include/ck/utility/array_multi_index.hpp index f692fb51430b7f5773f596e9bd87068f645886c1..9b8d5b95e9f6241b574104313b5b8fa2984ab18d 100644 --- a/include/ck/utility/array_multi_index.hpp +++ b/include/ck/utility/array_multi_index.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_ARRAY_MULTI_INDEX_HPP #define CK_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/c_style_pointer_cast.hpp b/include/ck/utility/c_style_pointer_cast.hpp index 8acf5790c67be55983021d79e7ce61c8c22ca999..6e8b0081587afcb2db159a05ecd5fb940def68ff 100644 --- a/include/ck/utility/c_style_pointer_cast.hpp +++ b/include/ck/utility/c_style_pointer_cast.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index e1b03285306cedaadbd9cee8122066c344ae703f..f2d1dcf670ae75a6875efd8e9627f7a1a6e23dea 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. 
All rights reserved. + #pragma once #include "config.hpp" #include "array.hpp" @@ -31,19 +34,49 @@ #include "thread_group.hpp" #include "debug.hpp" -#include "amd_buffer_addressing.hpp" -#include "generic_memory_space_atomic.hpp" -#include "get_id.hpp" -#include "synchronization.hpp" -#include "amd_address_space.hpp" -#include "static_buffer.hpp" -#include "dynamic_buffer.hpp" +#include "ck/ck.hpp" +#include "ck/utility/array.hpp" +#include "ck/utility/container_helper.hpp" +#include "ck/utility/statically_indexed_array.hpp" +#include "ck/utility/container_element_picker.hpp" +#include "ck/utility/multi_index.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/functional3.hpp" +#include "ck/utility/functional4.hpp" +#include "ck/utility/enable_if.hpp" +#include "ck/utility/ignore.hpp" +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/number.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/magic_division.hpp" +#include "ck/utility/c_style_pointer_cast.hpp" +#include "ck/utility/is_known_at_compile_time.hpp" +#include "ck/utility/transpose_vectors.hpp" +#include "ck/utility/inner_product.hpp" +#include "ck/utility/thread_group.hpp" +#include "ck/utility/debug.hpp" + +#include "ck/utility/amd_buffer_addressing.hpp" +#include "ck/utility/generic_memory_space_atomic.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/thread_group.hpp" +#include "ck/utility/synchronization.hpp" +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/static_buffer.hpp" +#include "ck/utility/dynamic_buffer.hpp" // TODO: remove this #if CK_USE_AMD_INLINE_ASM -#include "amd_inline_asm.hpp" +#include "ck/utility/amd_inline_asm.hpp" #endif #ifdef CK_USE_AMD_MFMA -#include "amd_xdlops.hpp" +#include "ck/utility/amd_xdlops.hpp" #endif diff --git a/include/ck/utility/container_element_picker.hpp b/include/ck/utility/container_element_picker.hpp index 54915125ac01e4f45cb3cc5cb814137899a31a7f..abc5185e04a5079edb95fe15bfcd6788256d384a 100644 --- a/include/ck/utility/container_element_picker.hpp +++ b/include/ck/utility/container_element_picker.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_CONTAINER_ELEMENT_PICKER_HPP #define CK_CONTAINER_ELEMENT_PICKER_HPP diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index a92e79908d967a620e2eaada1c73dc837cc3de04..c8b02bc5acaf8542bf028fff39c42061db3b3296 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 98f436f347219e245ef2f29c26ab023e733f4a67..7388b2a9c69693b56804fd69830d283bd9e82fdd 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "statically_indexed_array.hpp" + +#include "ck/utility/statically_indexed_array.hpp" #ifdef CK_NOGPU #include "half.hpp" #endif @@ -942,14 +946,14 @@ using int8x64_t = typename vector_type::type; // Convert X to Y template -__host__ __device__ Y type_convert(X x) +__host__ __device__ constexpr Y type_convert(X x) { return static_cast(x); } // convert bfp16 to fp32 template <> -inline __host__ __device__ float type_convert(bhalf_t x) +inline __host__ __device__ constexpr float type_convert(bhalf_t x) { union { @@ -962,7 +966,7 @@ inline __host__ __device__ float type_convert(bhalf_t x) // convert fp32 to bfp16 template <> -inline __host__ __device__ bhalf_t type_convert(float x) +inline __host__ __device__ constexpr bhalf_t type_convert(float x) { union { @@ -1014,6 +1018,11 @@ struct NumericLimits __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } + + __host__ __device__ static constexpr T QuietNaN() + { + return std::numeric_limits::quiet_NaN(); + } }; template <> @@ -1022,12 +1031,15 @@ struct NumericLimits static constexpr unsigned short binary_min = 0x0400; static constexpr unsigned short binary_max = 0x7BFF; static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } }; } // namespace ck diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index a2cdb9b8ba40896955b69b8ce4e4cbf95e585d3f..bdfe6579af9774cea4e52e8d88d61a78733395e7 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP #ifndef CK_NOGPU @@ -9,21 +12,27 @@ template struct PrintAsType; template -struct PrintAsType::value>::value> +struct PrintAsType::value>::type> { using type = float; + __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast(p)); } }; template <> struct PrintAsType { using type = float; + __host__ __device__ static void Print(const ck::half_t& p) + { + printf("%.3f ", static_cast(p)); + } }; template -struct PrintAsType::value>::value> +struct PrintAsType::value>::type> { using type = int; + __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast(p)); } }; } // namespace detail @@ -38,7 +47,6 @@ struct PrintAsType::value>::value template __device__ void print_shared(T const* p_shared, index_t num_elements) { - using PrintType = typename detail::PrintAsType::type; constexpr index_t row_elements = row_bytes / sizeof(T); static_assert((element_stride >= 1 && element_stride <= row_elements), "element_stride should between [1, row_elements]"); @@ -60,7 +68,7 @@ __device__ void print_shared(T const* p_shared, index_t num_elements) printf("elem %5d: ", i); for(index_t j = 0; j < row_elements; j += element_stride) { - printf("%.0f ", static_cast(p_shared[i + j])); + detail::PrintAsType::Print(p_shared[i + j]); } printf("\n"); diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 5bc52c6a26a3aeb37d800d754da905bcf15f0634..d00ad9a1ff569e0094df62096f583e36e8076272 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "config.hpp" + +#include "ck/ck.hpp" #include "enable_if.hpp" #include "c_style_pointer_cast.hpp" #include "amd_buffer_addressing.hpp" diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index 501e1bfc1cbbc8c07ebf3500edb468386b8a7f27..297434b0dddd5f1a680176c3bd099bedfda8dff4 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,5 +1,7 @@ -#ifndef CK_ENABLE_IF_HPP -#define CK_ENABLE_IF_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once namespace ck { @@ -10,4 +12,3 @@ template using enable_if_t = typename std::enable_if::type; } // namespace ck -#endif diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index b84b617f4493af40eccbd379d5a18d638beb4169..f5721a17ed9370ba0757455c2ed332a1ed5dcc46 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,8 +1,10 @@ -#ifndef CK_FUNCTIONAL_HPP -#define CK_FUNCTIONAL_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
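// The Print() members added to debug.hpp above let print_shared() format each
// element by type instead of forcing everything through "%.0f". A device-side
// sketch, assuming the detail namespace shown in that header is reachable as
// ck::debug::detail:

#include "ck/utility/debug.hpp"

__device__ void print_dispatch_example(const ck::half_t* lds_tile)
{
    ck::debug::detail::PrintAsType<ck::half_t>::Print(lds_tile[0]); // "%.3f " path
    ck::debug::detail::PrintAsType<ck::index_t>::Print(42);         // "%d " path
}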
-#include "integral_constant.hpp" -#include "type.hpp" +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/type.hpp" namespace ck { @@ -113,4 +115,3 @@ template using conditional_t = typename conditional::type; } // namespace ck -#endif diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index 371182a05e042e37cdb15c48a6f67daae939db61..6f125ca4c944777d02ae1083334bec4602ad68c4 100644 --- a/include/ck/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -1,8 +1,10 @@ -#ifndef CK_FUNCTIONAL2_HPP -#define CK_FUNCTIONAL2_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "functional.hpp" -#include "sequence.hpp" +#pragma once + +#include "ck/utility/functional.hpp" +#include "ck/utility/sequence.hpp" namespace ck { @@ -45,4 +47,3 @@ struct static_for }; } // namespace ck -#endif diff --git a/include/ck/utility/functional3.hpp b/include/ck/utility/functional3.hpp index 6a400f3ca62ceef18c345ee7e5508e38fc582108..06b67ef7e3fdec25a1ab2e18e23c904342233ebe 100644 --- a/include/ck/utility/functional3.hpp +++ b/include/ck/utility/functional3.hpp @@ -1,10 +1,13 @@ -#ifndef CK_FUNCTIONAL3_HPP -#define CK_FUNCTIONAL3_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "functional.hpp" -#include "functional2.hpp" -#include "sequence.hpp" -#include "multi_index.hpp" +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/multi_index.hpp" namespace ck { @@ -139,4 +142,3 @@ struct ford }; } // namespace ck -#endif diff --git a/include/ck/utility/functional4.hpp b/include/ck/utility/functional4.hpp index b03964438052dc866d348e204940d94ab805c17c..6eeaf15c9b7ac283f7fcac96c195d28f179b6065 100644 --- a/include/ck/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 175aca4073aa4dc6cc9f5018b76d219fac82fe77..63c6d3d7c450a21626c644d9bebfe76f6de6de93 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index 26d1075a538f99b89ab9ef962b812356919455bf..4fd7595cc22e8f74f9c62021304322138822a218 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "config.hpp" + +#include "ck/ck.hpp" #ifndef CK_NOGPU namespace ck { diff --git a/include/ck/utility/ignore.hpp b/include/ck/utility/ignore.hpp index 8a199159b3e193f80158c41a880f78334ed96419..0172458741354f5ca443c9482d06bc35623f58bc 100644 --- a/include/ck/utility/ignore.hpp +++ b/include/ck/utility/ignore.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef CK_IGNORE_HPP #define CK_IGNORE_HPP diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index ce4f72f4bbf81ec7f873766286c6cd1f27c1446e..558ef4090436f1943f5fa9c0b99fa52714a5f252 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "data_type.hpp" #ifndef CK_NOGPU diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 3d9c0472e7fdd6501003fd6d7886eb8831f89cbb..9aab4e24214a884bdfb391e0f9040bb3af2630ec 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -1,5 +1,7 @@ -#ifndef CK_INTEGRAL_CONSTANT_HPP -#define CK_INTEGRAL_CONSTANT_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once namespace ck { @@ -47,4 +49,3 @@ __host__ __device__ constexpr auto operator%(integral_constant, integral_ } } // namespace ck -#endif diff --git a/include/ck/utility/is_known_at_compile_time.hpp b/include/ck/utility/is_known_at_compile_time.hpp index dc44027901768226f8a976246d9bc467907fdc3f..8198154422e5d5228b79c6539bddcd5070d1e25c 100644 --- a/include/ck/utility/is_known_at_compile_time.hpp +++ b/include/ck/utility/is_known_at_compile_time.hpp @@ -1,7 +1,9 @@ -#ifndef IS_KNOWN_AT_COMPILE_TIME_HPP -#define IS_KNOWN_AT_COMPILE_TIME_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "config.hpp" +#pragma once + +#include "ck/ck.hpp" #include "integral_constant.hpp" #include "sequence.hpp" #include "tuple.hpp" @@ -52,4 +54,3 @@ struct is_known_at_compile_time> }; } // namespace ck -#endif diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 3806cae916ef1deb7fec59fea99e38ba04386488..e48a90bed32e8fc92e233642ae2cc2c35e83a7c2 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -1,7 +1,9 @@ -#ifndef CK_MAGIC_DIVISION_HPP -#define CK_MAGIC_DIVISION_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "config.hpp" +#pragma once + +#include "ck/ck.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" @@ -158,5 +160,3 @@ struct MagicDivision }; } // namespace ck - -#endif diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index 48438e6179d1bde0a39fc58c1d3e15c3e0396b7e..0cfc2f7da442a862725b6a666ec9686cc3098ec2 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -1,7 +1,9 @@ -#ifndef CK_MATH_HPP -#define CK_MATH_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "config.hpp" +#pragma once + +#include "ck/ck.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" @@ -142,6 +144,24 @@ __host__ __device__ constexpr auto min(X x, Ys... 
ys) return min(x, min(ys...)); } +// disallow implicit type casting +template +__device__ T exp(T x); + +// TODO: add f16 support using v_exp_f16 + +template <> +__device__ float exp(float x) +{ + return __expf(x); +} + +template <> +__device__ double exp(double x) +{ + return exp(x); +} + // greatest common divisor, aka highest common factor __host__ __device__ constexpr index_t gcd(index_t x, index_t y) { @@ -212,5 +232,3 @@ struct less } // namespace math } // namespace ck - -#endif diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 438f5e12bdb71ce9aff27cbb3807b6e025bbfb14..fc264117f08e2a39e8c0386f82274f6d83060c51 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -1,9 +1,12 @@ -#ifndef CK_MATH_V2_HPP -#define CK_MATH_V2_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include -#include "data_type.hpp" -#include "type.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/type.hpp" namespace ck { namespace math { @@ -112,5 +115,3 @@ static inline __device__ double sqrt(double x) { return ::sqrt(x); }; } // namespace math } // namespace ck - -#endif diff --git a/include/ck/utility/multi_index.hpp b/include/ck/utility/multi_index.hpp index f395b5ee7151272ddea0c1aa183e7a26ae4017d1..1d544c0906cae3c61f4d7ce27e74e7636c3919b5 100644 --- a/include/ck/utility/multi_index.hpp +++ b/include/ck/utility/multi_index.hpp @@ -1,5 +1,7 @@ -#ifndef CK_MULTI_INDEX_HPP -#define CK_MULTI_INDEX_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include "common_header.hpp" @@ -8,5 +10,3 @@ #else #include "statically_indexed_array_multi_index.hpp" #endif - -#endif diff --git a/include/ck/utility/number.hpp b/include/ck/utility/number.hpp index 97a71f8a4119bdc8781b72f279a33572ea4b7488..f3ca6b61dc6ac08330a6cf1633148bb9bad8cd81 100644 --- a/include/ck/utility/number.hpp +++ b/include/ck/utility/number.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_NUMBER_HPP #define CK_NUMBER_HPP diff --git a/include/ck/utility/print.hpp b/include/ck/utility/print.hpp index d7d58bbb8355e22b761f29b22a79efd1f160975a..eed1ca42c7346e32f4d6924452d568bf431d6c24 100644 --- a/include/ck/utility/print.hpp +++ b/include/ck/utility/print.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_PRINT_HPP #define CK_PRINT_HPP diff --git a/include/ck/utility/reduction_common.hpp b/include/ck/utility/reduction_common.hpp index a34cfce8377e726ca68af8bdaa548b488468c44f..aceef7b296da7466c484aad8033fe1f949662fca 100644 --- a/include/ck/utility/reduction_common.hpp +++ b/include/ck/utility/reduction_common.hpp @@ -1,32 +1,9 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. 
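// The undefined primary template above makes ck::math::exp reject, at link
// time, any type that lacks an explicit specialization instead of silently
// promoting it; only float (via __expf) and double are provided so far. A
// sketch of a caller:

__device__ float sigmoid(float x)
{
    // lowers to the fast __expf intrinsic through the float specialization
    return 1.0f / (1.0f + ck::math::exp(-x));
}
// ck::math::exp(ck::half_t{}) would not link until the f16 specialization
// mentioned in the TODO above is added.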
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_COMMON_HPP -#define CK_REDUCTION_COMMON_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "reduction_enums.hpp" +#pragma once + +#include "ck/utility/reduction_enums.hpp" namespace ck { @@ -60,6 +37,4 @@ constexpr __device__ index_t get_shift<1>() return (0); } -}; // end of namespace ck - -#endif +} // namespace ck diff --git a/include/ck/utility/reduction_enums.hpp b/include/ck/utility/reduction_enums.hpp index 9089fd6116c2228506166d5ee0c9e6e4cafedb0f..67856331059cef628a62e4b2223f262b839bbf02 100644 --- a/include/ck/utility/reduction_enums.hpp +++ b/include/ck/utility/reduction_enums.hpp @@ -1,30 +1,7 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_ENUMS_HPP -#define CK_REDUCTION_ENUMS_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once namespace ck { @@ -61,6 +38,4 @@ enum struct IndicesType INDICES_8BIT = 3, }; -}; // end of namespace ck - -#endif +} // namespace ck diff --git a/include/ck/utility/reduction_functions_accumulate.hpp b/include/ck/utility/reduction_functions_accumulate.hpp index 22175c5bcc27c8a88cdc188a47c2c3a8d0b85767..724e5599d6c878ba08bf14429c40c0958ab9f2a2 100644 --- a/include/ck/utility/reduction_functions_accumulate.hpp +++ b/include/ck/utility/reduction_functions_accumulate.hpp @@ -1,43 +1,37 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_BINOP_HPP -#define CK_REDUCTION_FUNCTIONS_BINOP_HPP - -#include "data_type.hpp" -#include "math_v2.hpp" - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" namespace ck { namespace detail { +// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs) +template +struct AccumulateWithNanIgnore +{ + __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + if(!ck::math::isnan(currVal)) + { + ReduceOperation{}(accuVal, currVal); + } + }; +}; + template struct AccumulateWithNanCheck; +// Does not check for NaN; does not guarantee NaNs be propagated to result +// e.g., given that max(a, b) = a > b ? 
a : b +// then max(NaN, 1) returns 1 +// max(1, NaN) returns NaN +// since any comparison involving NaNs returns false template struct AccumulateWithNanCheck { @@ -48,6 +42,7 @@ struct AccumulateWithNanCheck }; }; +// Check for NaN; guarantees NaNs be propagated to result template struct AccumulateWithNanCheck { @@ -116,7 +111,5 @@ struct AccumulateWithIndexAndNanCheck struct Add { - using dataType = T; + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::AtomicAdd || + operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Add accumulator!"); - __host__ __device__ static constexpr T GetIdentityValue() { return static_cast(0.0f); }; + a = a + b; + } +}; - __device__ static constexpr bool +struct SquaredAdd +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + __host__ __device__ static constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) { return operation == InMemoryDataOperationEnum::AtomicAdd || operation == InMemoryDataOperationEnum::Set; }; - __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max accumulator!"); + + a = a + b * b; + } }; -template struct Mul { - using dataType = T; - - __host__ __device__ static constexpr T GetIdentityValue() { return static_cast(1.0f); }; + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(1.0f); + }; - __device__ static constexpr bool + __host__ __device__ static constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) { return operation == InMemoryDataOperationEnum::Set; }; - __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Mul accumulator!"); + + a = a * b; + } }; -template struct Max { - using dataType = T; - + template __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits::Lowest(); }; - __device__ static constexpr bool + __host__ __device__ static constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) { // ToChange: atomic_max to be added return operation == InMemoryDataOperationEnum::Set; }; + template __host__ __device__ inline constexpr void operator()(T& a, T b) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max accumulator!"); + if(a < b) a = b; } + template __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max 
accumulator!"); + if(a < b) { a = b; @@ -120,28 +153,41 @@ struct Max } }; -template struct Min { - using dataType = T; - - __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits::Max(); }; + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return NumericLimits::Max(); + }; - __device__ static constexpr bool + __host__ __device__ static constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) { // ToChange: atomic_min to be added return operation == InMemoryDataOperationEnum::Set; }; + template __host__ __device__ inline constexpr void operator()(T& a, T b) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Min accumulator!"); + if(a > b) a = b; } + template __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Min accumulator!"); + if(a > b) { a = b; @@ -150,28 +196,41 @@ struct Min } }; -template struct AMax { - using dataType = T; - - __host__ __device__ static constexpr T GetIdentityValue() { return static_cast(0.0f); }; + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; - __device__ static constexpr bool + __host__ __device__ static constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) { // ToChange: atomic_max to be added return operation == InMemoryDataOperationEnum::Set; }; + template __host__ __device__ inline constexpr void operator()(T& a, T b) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the AMax accumulator!"); + if(a < b) a = b; } + template __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the AMax accumulator!"); + if(a < b) { a = b; @@ -181,7 +240,7 @@ struct AMax }; template -T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation) +constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation) { T result = ck::type_convert(0.0f); @@ -191,8 +250,43 @@ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation return (result); }; -}; // end of namespace reduce +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = false; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value; +}; -} // end of namespace ck +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value; +}; -#endif +} // namespace reduce +} // namespace ck diff --git a/include/ck/utility/sequence.hpp 
b/include/ck/utility/sequence.hpp index c2adfc5063ff39dece40f351a48291e275408723..97b597221c2850d4b28b644266978fc56b295913 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -1,10 +1,12 @@ -#ifndef CK_SEQUENCE_HPP -#define CK_SEQUENCE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "integral_constant.hpp" -#include "type.hpp" -#include "functional.hpp" -#include "math.hpp" +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/math.hpp" namespace ck { @@ -241,7 +243,13 @@ struct arithmetic_sequence_gen } }; - using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type1 = Sequence<>; + + static constexpr bool kHasContent = + (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd); + + using type = typename conditional::type; }; // uniform sequence @@ -882,5 +890,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f) return flag; } +template +using sequence_merge_t = typename sequence_merge::type; + +template +using uniform_sequence_gen_t = typename uniform_sequence_gen::type; + } // namespace ck -#endif diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 88d7da63e8a8b1ecc5d675b4aed0c03c34f06a8b..db25c27e70c3653f94524367b8a6bce79480113e 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -1,7 +1,9 @@ -#ifndef CK_SEQUENCE_HELPER_HPP -#define CK_SEQUENCE_HELPER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "tuple.hpp" +#pragma once + +#include "ck/utility/tuple.hpp" namespace ck { @@ -33,4 +35,3 @@ __host__ __device__ constexpr auto to_sequence(Tuple...>) } } // namespace ck -#endif diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index df413852019b6f5cea979f6a319761126c7b8a40..1ce29c3ab32827d3b330f951695c3f70aaaf1c59 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_STATIC_BUFFER_HPP #define CK_STATIC_BUFFER_HPP diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index 526be2a07ac8d63cb48e6ac022ac4d22b6b1802e..3438776f413cf010664cc8aa4a18c09bf161fae7 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_STATICALLY_INDEXED_ARRAY_HPP #define CK_STATICALLY_INDEXED_ARRAY_HPP diff --git a/include/ck/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp index e0ee9d04fdb7e0b63e5f64c8b1d49dd644fa9f8f..bab5aebff78e8c8a52502e746b8cd2242ac9a3bc 100644 --- a/include/ck/utility/statically_indexed_array_multi_index.hpp +++ b/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
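// The kHasContent guard added to arithmetic_sequence_gen above makes degenerate
// ranges legal: they now collapse to an empty Sequence<> instead of
// instantiating sequence_gen with a non-positive length. A compile-time sketch:

#include "ck/utility/sequence.hpp"

using EveryOther = typename ck::arithmetic_sequence_gen<0, 8, 2>::type; // Sequence<0, 2, 4, 6>
using EmptyRange = typename ck::arithmetic_sequence_gen<4, 4, 1>::type; // Sequence<>

static_assert(EveryOther::Size() == 4, "four elements generated");
static_assert(EmptyRange::Size() == 0, "empty range maps to Sequence<>");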
+ #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index 19887b90ed96f65141929c6f552969a16bbe75c9..e48149bdea0acb94588574577be5b56cbef5d091 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -1,8 +1,9 @@ -#ifndef CK_SYNCHRONIZATION_AMD_HPP -#define CK_SYNCHRONIZATION_AMD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +#pragma once #ifndef CK_NOGPU -#include "config.hpp" +#include "ck/ck.hpp" namespace ck { @@ -20,4 +21,4 @@ __device__ void block_sync_lds() } // namespace ck #endif -#endif + diff --git a/include/ck/utility/thread_group.hpp b/include/ck/utility/thread_group.hpp index 0753edbc02d4a54d53c69a6dc8f6e8feb9e6f09e..2d55e58d52e730d4fb9fd1a7802324555e250e75 100644 --- a/include/ck/utility/thread_group.hpp +++ b/include/ck/utility/thread_group.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #ifndef CK_NOGPU #include "get_id.hpp" diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp index 5f70a8daae50ddc0502450c7b20e9954550e18cc..e1995f40c186dc9f403debe04062b65c2044399f 100644 --- a/include/ck/utility/transpose_vectors.hpp +++ b/include/ck/utility/transpose_vectors.hpp @@ -1,7 +1,9 @@ -#ifndef CK_TRANSPOSE_VECTORS_AMD_HPP -#define CK_TRANSPOSE_VECTORS_AMD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "config.hpp" +#pragma once + +#include "ck/ck.hpp" #include "statically_indexed_array.hpp" #include "data_type.hpp" @@ -167,4 +169,3 @@ struct transpose_vectors } // namespace ck #endif -#endif \ No newline at end of file diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 766a78240bd376c6f506d0686ed987d94cc64499..07bf721d54b62ce475dd0d44246e833d982f15df 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -1,10 +1,12 @@ -#ifndef CK_TUPLE_HPP -#define CK_TUPLE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "integral_constant.hpp" -#include "sequence.hpp" -#include "type.hpp" -#include "enable_if.hpp" +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/enable_if.hpp" namespace ck { @@ -17,14 +19,18 @@ struct TupleElementKey }; template -struct TupleElement +struct TupleElementKeyData { - __host__ __device__ constexpr TupleElement() = default; +#if 0 // workaround compiler complaint about implicitly-deleted default constructor + __host__ __device__ constexpr TupleElementKeyData() = default; +#else + __host__ __device__ constexpr TupleElementKeyData() : mData{} {} +#endif - template < - typename T, - typename enable_if, TupleElement>::value, bool>::type = false> - __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) + template , TupleElementKeyData>::value, + bool>::type = false> + __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward(v)) { } @@ -32,20 +38,21 @@ struct TupleElement }; template -__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement& x) +__host__ __device__ constexpr const Data& +get_tuple_element_data(const TupleElementKeyData& x) { return static_cast(x.mData); } template -__host__ __device__ constexpr Data& get_tuple_element(TupleElement& x) +__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData& x) { return x.mData; } // TODO: not sure the use of reference is correct template -__host__ __device__ constexpr Data&& get_tuple_element(TupleElement&& x) +__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData&& x) { return static_cast(x.mData); } @@ -54,7 +61,7 @@ template struct TupleImpl; template -struct TupleImpl, Xs...> : TupleElement, Xs>... +struct TupleImpl, Xs...> : TupleElementKeyData, Xs>... { __host__ __device__ constexpr TupleImpl() = default; @@ -63,13 +70,13 @@ struct TupleImpl, Xs...> : TupleElement, Xs> !is_same, TupleImpl>::value, bool>::type = false> __host__ __device__ constexpr TupleImpl(Y&& y) - : TupleElement, Xs>(std::forward(y))... + : TupleElementKeyData, Xs>(std::forward(y))... { } template = 2, bool>::type = false> __host__ __device__ constexpr TupleImpl(Ys&&... ys) - : TupleElement, Xs>(std::forward(ys))... + : TupleElementKeyData, Xs>(std::forward(ys))... { static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), "wrong! inconsistent size"); @@ -78,15 +85,15 @@ struct TupleImpl, Xs...> : TupleElement, Xs> __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } template - __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey) const + __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey) const { - return get_tuple_element>(*this); + return get_tuple_element_data>(*this); } template - __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey) + __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey) { - return get_tuple_element>(*this); + return get_tuple_element_data>(*this); } }; @@ -121,7 +128,7 @@ struct Tuple : detail::TupleImpl) const { static_assert(I < base::Size(), "wrong! out of range"); - return base::GetElementByKey(detail::TupleElementKey{}); + return base::GetElementDataByKey(detail::TupleElementKey{}); } // write access @@ -129,7 +136,7 @@ struct Tuple : detail::TupleImpl) { static_assert(I < base::Size(), "wrong! 
out of range"); - return base::GetElementByKey(detail::TupleElementKey{}); + return base::GetElementDataByKey(detail::TupleElementKey{}); } // read access @@ -159,6 +166,31 @@ struct Tuple : detail::TupleImpl +struct Tuple<> +{ + __host__ __device__ constexpr Tuple() = default; + + __host__ __device__ static constexpr index_t Size() { return 0; } + + template + __host__ __device__ constexpr auto operator=(const T&) + { + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } +}; + +template +struct tuple_element +{ + using type = decltype(TTuple{}.At(Number{})); +}; + +template +using tuple_element_t = typename tuple_element::type; + template __host__ __device__ constexpr auto make_tuple(Xs&&... xs) { @@ -173,4 +205,3 @@ constexpr Tuple tie(Args&... args) noexcept } } // namespace ck -#endif diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 4e5b9cf97c8e9d75244ef6162676680651aaa602..6f5b142a5e7a765807bd9d7f556f7b8afc512d37 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -1,5 +1,7 @@ -#ifndef CK_TUPLE_HELPER_HPP -#define CK_TUPLE_HELPER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include "functional4.hpp" #include "tuple.hpp" @@ -20,6 +22,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number) typename arithmetic_sequence_gen<0, N, 1>::type{}); } +// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue) +template +__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple& tx, + const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return Tuple{std::forward(zs)...}; }, + tx, + ty); +} + namespace detail { template @@ -66,4 +79,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, } } // namespace ck -#endif diff --git a/include/ck/utility/type.hpp b/include/ck/utility/type.hpp index ee3189ebe5fc1ca7367f5c3a3e4c56f4dfc83939..90b9df2950b7d979f8cef386d155c72f6d7a39a5 100644 --- a/include/ck/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -1,9 +1,11 @@ -#ifndef CK_TYPE_HPP -#define CK_TYPE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "config.hpp" -#include "integral_constant.hpp" -#include "enable_if.hpp" +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/enable_if.hpp" namespace ck { @@ -56,4 +58,3 @@ __host__ __device__ constexpr Y bit_cast(const X& x) } } // namespace ck -#endif diff --git a/library/CMakeLists.txt b/library/CMakeLists.txt index 56c95d6676a06daf89e7f7b247f52739a0514603..59575bdba6ec11be61870219c521fd34cd1a09a9 100644 --- a/library/CMakeLists.txt +++ b/library/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(src/host_tensor) add_subdirectory(src/tensor_operation_instance/gpu) +add_subdirectory(src/host_tensor) add_subdirectory(src/utility) add_subdirectory(src/tensor_operation_instance/cpu) diff --git a/library/include/ck/library/host/host_interface.hpp b/library/include/ck/library/host/host_interface.hpp deleted file mode 100644 index 955da0f4bee3a43cd38446b485cac3d6d59f68ab..0000000000000000000000000000000000000000 --- a/library/include/ck/library/host/host_interface.hpp +++ /dev/null @@ -1,54 +0,0 @@ -#pragma once - -#include -#include - -#include "stream_config.hpp" -#include "config.hpp" -#include "device_base.hpp" - -struct DeviceConvFwdPtr_t -{ - using BaseArgument = ck::tensor_operation::device::BaseArgument; - using BaseInvoker = ck::tensor_operation::device::BaseInvoker; - - struct DeviceConvFwdPtrImpl; - std::unique_ptr pImpl; - DeviceConvFwdPtr_t(); - ~DeviceConvFwdPtr_t(); - DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&); - DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&); - DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete; - DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete; - std::unique_ptr - MakeArgumentPointer(void* in_ptr, - void* wei_ptr, - void* out_ptr, - size_t N, - size_t K, - size_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) - const; // in,wei and out element ops are ignored for now since even if we change them, they - // cant be linked - std::unique_ptr - MakeInvokerPointer() const; // requires including BaseInvoker headers - std::string GetTypeString(); - bool IsSupportedArgument(const BaseArgument* arg_ptr); -}; - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t( - std::vector& instances); -void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t( - std::vector& instances); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t( - std::vector& instances); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t( - std::vector& instances); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t( - std::vector& instances); diff --git a/library/include/ck/library/host_tensor/conv_common.hpp b/library/include/ck/library/host_tensor/conv_common.hpp index b60af7d664f327e029fc0b80cb24876e42c36efe..6fad9f7d77d349853e63512ae486a57e7ceb7b6f 100644 --- a/library/include/ck/library/host_tensor/conv_common.hpp +++ b/library/include/ck/library/host_tensor/conv_common.hpp @@ -1,7 +1,9 @@ -#ifndef CONV_COMMON_HPP -#define CONV_COMMON_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "tensor_descriptor.hpp" +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" template -inline auto activ(T v, const ck::ActivTypeEnum activ_type) -{ - const T alpha = 0.3; - switch(activ_type) - { - case ck::ActivTypeEnum::None: return v; - case ck::ActivTypeEnum::LeakyRelu: return (v >= 0 ? v : alpha * v); - case ck::ActivTypeEnum::Sigmoid: return (1 / (1 + exp(-v))); - default: throw std::runtime_error("unsupported activ type"); break; - } -} - -#endif diff --git a/library/include/ck/library/host_tensor/device.hpp b/library/include/ck/library/host_tensor/device.hpp deleted file mode 100644 index 883f3ecc46d59d38b470a661f1bc70c2bd441c9a..0000000000000000000000000000000000000000 --- a/library/include/ck/library/host_tensor/device.hpp +++ /dev/null @@ -1,183 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "ck/options.hpp" -#ifndef CK_NOGPU -#include -#include -#endif -#include "stream_config.hpp" - -#ifndef CK_NOGPU -template -__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) -{ - for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x) - { - p[i] = x; - } -} - -inline void hip_check_error(hipError_t x) -{ - if(x != hipSuccess) - { - std::ostringstream ss; - ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ - << "in function: " << __func__; - throw std::runtime_error(ss.str()); - } -} - -struct DeviceMem -{ - DeviceMem() = delete; - DeviceMem(std::size_t mem_size); - void* GetDeviceBuffer(); - std::size_t GetBufferSize(); - void ToDevice(const void* p); - void FromDevice(void* p); - void SetZero(); - template - void SetValue(T x) - { - if(mMemSize % sizeof(T) != 0) - { - throw std::runtime_error("wrong! not entire DeviceMem will be set"); - } - - set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); - } - ~DeviceMem(); - - void* mpDeviceBuf; - std::size_t mMemSize; -}; - -struct KernelTimerImpl; - -struct KernelTimer -{ - KernelTimer(); - ~KernelTimer(); - void Start(); - void End(); - float GetElapsedTime() const; - - std::unique_ptr impl; -}; - -using device_stream_t = hipStream_t; - -template -float launch_and_time_kernel(const StreamConfig& stream_config, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... 
args) -{ -#if CK_TIME_KERNEL - if(stream_config.time_kernel_) - { - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - const int nrepeat = 10; - - printf("Warm up 1 time\n"); - - // warm up - kernel<<>>(args...); - - printf("Start running %d times...\n", nrepeat); - - KernelTimer timer; - timer.Start(); - - for(int i = 0; i < nrepeat; ++i) - { - kernel<<>>(args...); - } - - timer.End(); - - return timer.GetElapsedTime() / nrepeat; - } - else - { - kernel<<>>(args...); - - return 0; - } -#else - kernel<<>>(args...); - - return 0; -#endif -} -#endif - -struct DeviceAlignedMemCPU -{ - DeviceAlignedMemCPU() = delete; - DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment); - void* GetDeviceBuffer(); - std::size_t GetBufferSize(); - void ToDevice(const void* p); - void FromDevice(void* p); - void SetZero(); - ~DeviceAlignedMemCPU(); - - void* mpDeviceBuf; - std::size_t mMemSize; - std::size_t mAlignment; -}; - -struct WallTimerImpl; - -struct WallTimer -{ - WallTimer(); - ~WallTimer(); - void Start(); - void End(); - float GetElapsedTime() const; - - std::unique_ptr impl; -}; - -template -void launch_cpu_kernel(F kernel, Args... args) -{ - kernel(args...); -} - -template -float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args) -{ - WallTimer timer; - - int nwarmup = 3; - - for(int i = 0; i < nwarmup; i++) - kernel(args...); - - timer.Start(); - for(int i = 0; i < nrepeat; i++) - { - kernel(args...); - } - timer.End(); - - return timer.GetElapsedTime() / nrepeat; -} diff --git a/library/include/ck/library/host_tensor/device_memory.hpp b/library/include/ck/library/host_tensor/device_memory.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5667db7fc775115794da2fb92c735e496f66f836 --- /dev/null +++ b/library/include/ck/library/host_tensor/device_memory.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +template +__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) +{ + for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x) + { + p[i] = x; + } +} + +struct DeviceMem +{ + DeviceMem() = delete; + DeviceMem(std::size_t mem_size); + void* GetDeviceBuffer(); + std::size_t GetBufferSize(); + void ToDevice(const void* p); + void FromDevice(void* p); + void SetZero(); + template + void SetValue(T x) + { + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! 
not entire DeviceMem will be set"); + } + + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; diff --git a/library/include/ck/library/host_tensor/device_tensor.hpp b/library/include/ck/library/host_tensor/device_tensor.hpp deleted file mode 100644 index b8d3ccc8a0ba0fb37a2a4d6ac4a22dfc706f3f52..0000000000000000000000000000000000000000 --- a/library/include/ck/library/host_tensor/device_tensor.hpp +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout) -{ - ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}), os); -} diff --git a/library/include/ck/library/host_tensor/host_common_util.hpp b/library/include/ck/library/host_tensor/host_common_util.hpp index 8fc1d3643045ad8d045cb8a2d564db20caa75adf..31e5571eede5bc99ac8c20ccaf403e122ef986ac 100644 --- a/library/include/ck/library/host_tensor/host_common_util.hpp +++ b/library/include/ck/library/host_tensor/host_common_util.hpp @@ -1,37 +1,14 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef GUARD_HOST_COMMON_UTIL_HPP -#define GUARD_HOST_COMMON_UTIL_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include #include #include -#include "config.hpp" +#include "ck/ck.hpp" namespace ck { @@ -95,8 +72,5 @@ static inline std::vector getTypeValuesFromString(const char* cstr_values) return (values); } -}; // namespace host_common - -}; // namespace ck - -#endif +} // namespace host_common +} // namespace ck diff --git a/library/include/ck/library/host_tensor/host_conv.hpp b/library/include/ck/library/host_tensor/host_conv.hpp index 3d2588c08b417f87ddef10a72a5ea1b12379bb03..8348a3089f4606563caa5582ac427154308e7fa5 100644 --- a/library/include/ck/library/host_tensor/host_conv.hpp +++ b/library/include/ck/library/host_tensor/host_conv.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
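// Typical host-side lifecycle of the DeviceMem helper declared in the new
// device_memory.hpp above (error handling omitted; allocation is assumed to be
// done by the out-of-line constructor, and the function name is illustrative
// only):

#include <vector>
#include "ck/library/host_tensor/device_memory.hpp"

inline void device_mem_example()
{
    std::vector<float> host(1024, 1.0f);

    DeviceMem buf(host.size() * sizeof(float)); // size fixed at construction
    buf.ToDevice(host.data());                  // host -> device copy
    buf.SetValue(0.5f);                         // fills via the set_buffer_value kernel
    buf.FromDevice(host.data());                // device -> host copy
}                                               // ~DeviceMem releases the allocation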
+ #pragma once #include "host_tensor.hpp" #include "conv_common.hpp" diff --git a/library/include/ck/library/host_tensor/host_gemm.hpp b/library/include/ck/library/host_tensor/host_gemm.hpp index 211c01c01a7ba14ac5091312a54fbfbcc87c0aa4..44036d0234375824ed3aa773ffe71de02b84c570 100644 --- a/library/include/ck/library/host_tensor/host_gemm.hpp +++ b/library/include/ck/library/host_tensor/host_gemm.hpp @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + #include "host_tensor.hpp" template #include #include -#include "reduction_enums.hpp" -#include "reduction_common.hpp" -#include "host_common_util.hpp" -#include "host_tensor.hpp" -#include "data_type.hpp" -#include "reduction_functions_accumulate.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/library/host_tensor/host_common_util.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" template static void get_all_indexes(const std::array& dimLengths, @@ -174,15 +150,18 @@ struct ReductionHost const InDataType* in_data, float beta, OutDataType* out_data, - IndexDataType* out_indices) + IndexDataType* out_indices, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) { if constexpr(OutputIndex) { - RunImpl_with_index(alpha, in_data, beta, out_data, out_indices); + RunImpl_with_index( + alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op); } else { - RunImpl_no_index(alpha, in_data, beta, out_data); + RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op); }; }; @@ -190,7 +169,9 @@ struct ReductionHost const InDataType* in_data, float beta, OutDataType* out_data, - IndexDataType* out_indices) + IndexDataType* out_indices, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) { using ck::float_equal_one; using ck::float_equal_zero; @@ -200,12 +181,10 @@ struct ReductionHost ReduceOperation, AccDataType, IndexDataType>; - InElementwiseOperation in_elementwise_op(divider); - AccElementwiseOperation acc_elementwise_op(divider); if constexpr(NumInvariantDim == 0) { - AccDataType accuVal = ReduceOperation::GetIdentityValue(); + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); IndexDataType accuIndex = 0; for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) @@ -236,7 +215,7 @@ struct ReductionHost else { auto thread_reduce_func = [&](auto invariant_index) { - AccDataType accuVal = ReduceOperation::GetIdentityValue(); + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); IndexDataType accuIndex = 0; auto offset_invariant = @@ -297,7 +276,12 @@ struct ReductionHost }; }; - void RunImpl_no_index(float alpha, const InDataType* in_data, float beta, OutDataType* out_data) + void RunImpl_no_index(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) { using ck::float_equal_one; using ck::float_equal_zero; @@ -306,12 +290,9 @@ struct ReductionHost using Accumulation = ck::detail::AccumulateWithNanCheck; - InElementwiseOperation in_elementwise_op(divider); - AccElementwiseOperation acc_elementwise_op(divider); - if constexpr(NumInvariantDim == 0) { - AccDataType accuVal = ReduceOperation::GetIdentityValue(); + AccDataType accuVal 
= ReduceOperation::template GetIdentityValue(); for(const auto& reduce_index : reduce_dim_indexes) { @@ -338,7 +319,7 @@ struct ReductionHost else { auto thread_reduce_func = [&](auto invariant_index) { - AccDataType accuVal = ReduceOperation::GetIdentityValue(); + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); auto offset_invariant = get_offset_from_index(invariantStrides, invariant_index); @@ -395,5 +376,3 @@ struct ReductionHost }; }; }; - -#endif diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index 5aa1493786895bf6df30fc7e3f9ce31897c52283..4c20d8078bb0e2048128db28b8de517033f29b42 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -1,5 +1,7 @@ -#ifndef HOST_TENSOR_HPP -#define HOST_TENSOR_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include @@ -8,7 +10,7 @@ #include #include #include -#include "common_header.hpp" +#include "ck/utility/data_type.hpp" template std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) @@ -107,6 +109,11 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } + std::size_t GetOffsetFromMultiIndex(std::vector iss) const + { + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); private: @@ -212,6 +219,72 @@ struct Tensor Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {} + template + Tensor CopyAsType() + { + Tensor ret(mDesc); + for(size_t i = 0; i < mData.size(); i++) + { + ret.mData[i] = static_cast(mData[i]); + } + return ret; + } + + Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {} + + Tensor& operator=(const Tensor& other) + { + mDesc = other.mDesc; + mData = other.mData; + return *this; + } + + template + void ForEach_impl(F&& f, std::vector& idx, size_t rank) + { + if(rank == mDesc.GetNumOfDimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(F&& f) + { + std::vector idx(mDesc.GetNumOfDimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + + template + void ForEach_impl(const F&& f, std::vector& idx, size_t rank) const + { + if(rank == mDesc.GetNumOfDimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(const F&& f) const + { + std::vector idx(mDesc.GetNumOfDimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + template void GenerateTensorValue(G g, std::size_t num_thread = 1) { @@ -272,6 +345,16 @@ struct Tensor return mData[mDesc.GetOffsetFromMultiIndex(is...)]; } + T& operator()(std::vector idx) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } + + const T& operator()(std::vector idx) const + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } + typename std::vector::iterator begin() { return mData.begin(); } typename std::vector::iterator end() { return mData.end(); } @@ -285,7 +368,8 @@ struct Tensor }; template 
-HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens) : mLens(lens) +HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens) + : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } @@ -293,17 +377,12 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens) : mLens(l template HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens, const std::vector& strides) - : mLens(lens), mStrides(strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { } -void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); - #if 1 // FIXME: remove -void bf16_to_f32_(const Tensor& src, Tensor& dst); -#endif - template float check_error(const Tensor& ref, const Tensor& result) { @@ -349,5 +428,4 @@ float check_error(const Tensor& ref, const Tensor& result) return linf_error; } - #endif diff --git a/library/include/ck/library/host_tensor/host_tensor_generator.hpp b/library/include/ck/library/host_tensor/host_tensor_generator.hpp index 17e20351f04dedb7488e93327e699ab9d9952b6d..e0bd4991ef988705cd5f36c919f047dc6592ba04 100644 --- a/library/include/ck/library/host_tensor/host_tensor_generator.hpp +++ b/library/include/ck/library/host_tensor/host_tensor_generator.hpp @@ -1,9 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include #include -#include "config.hpp" +#include "ck/ck.hpp" template struct GeneratorTensor_0 @@ -18,12 +21,12 @@ struct GeneratorTensor_0 template struct GeneratorTensor_1 { - int value = 1; + T value = 1; template T operator()(Is...) { - return ck::type_convert(value); + return value; } }; diff --git a/library/include/ck/library/obselete_driver_offline/debug.hpp b/library/include/ck/library/obselete_driver_offline/debug.hpp deleted file mode 100644 index 72fd0763ba5781609e22591917cc24cede8fd677..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/debug.hpp +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef DEBUG_HPP -#define DEBUG_HPP - -namespace debug { -namespace debug_driver_gemm_xdlops_v2r3 { - -// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping -static ck::index_t M01 = 1; -static ck::index_t N01 = 1; - -} // namespace debug_driver_gemm_xdlops_v2r3 -} // namespace debug -#endif diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp deleted file mode 100644 index debb5058e7247923bfac8d5dab721c5f06082297..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp +++ /dev/null @@ -1,220 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" - -template -void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( - const InLengths& in_n_c0_hi_wi_c1_lengths, - const WeiLengths& wei_k_c0_y_x_c1_lengths, - const AddLengths& add_n_k0_hox2_wox2_k1_lengths, - const OutLengths& out_n_k0_ho_wo_k1_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& 
in_right_pads, - const Tensor& in_n_c0_hi_wi_c1, - const Tensor& wei_k_c0_y_x_c1, - const Tensor& bias_k0_k1, - const Tensor& add_n_k0_hox2_wox2_k1, - Tensor& add_n_k0_hox2_wox2_k1_out, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - const auto N = out_n_k0_ho_wo_k1_lengths[I0]; - const auto K0 = out_n_k0_ho_wo_k1_lengths[I1]; - const auto Ho = out_n_k0_ho_wo_k1_lengths[I2]; - const auto Wo = out_n_k0_ho_wo_k1_lengths[I3]; - const auto K1 = out_n_k0_ho_wo_k1_lengths[I4]; - - const auto C0 = in_n_c0_hi_wi_c1_lengths[I1]; - const auto Hi = in_n_c0_hi_wi_c1_lengths[I2]; - const auto Wi = in_n_c0_hi_wi_c1_lengths[I3]; - const auto C1 = in_n_c0_hi_wi_c1_lengths[I4]; - - const auto K = wei_k_c0_y_x_c1_lengths[I0]; - const auto Y = wei_k_c0_y_x_c1_lengths[I2]; - const auto X = wei_k_c0_y_x_c1_lengths[I3]; - - const auto Hox2 = add_n_k0_hox2_wox2_k1_lengths[I2]; - const auto Wox2 = add_n_k0_hox2_wox2_k1_lengths[I3]; - - DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * - in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); - DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); - DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace()); - DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) * - add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace()); - - in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); - wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); - bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data()); - add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data()); - - constexpr index_t InWeiVectorSize = 8; - - if(C1 % InWeiVectorSize != 0) - { - throw std::runtime_error("wrong! 
C1 cannot be divided by InWeiVectorSize"); - } - -#if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 32; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 64; - - constexpr index_t E1 = C0 * 9; - constexpr index_t E2 = 1; - constexpr index_t E1PerBlock = C0; - - constexpr index_t KPerThread = 16; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = 1; - - using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>; - using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; - constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; - - constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; - - constexpr index_t CThreadTransferDstScalarPerVector_K = K1; -#elif 1 - constexpr auto BlockSize = 64; - - constexpr auto KPerBlock = 8; - constexpr auto HoPerBlock = 8; - constexpr auto WoPerBlock = 32; - - constexpr auto E1 = 2 * 9; - constexpr auto E2 = 1; - constexpr auto K2 = 2; - constexpr auto E1PerBlock = 2; - - constexpr auto KPerThread = KPerBlock; - constexpr auto HoPerThread = 2; - constexpr auto WoPerThread = 2; - constexpr auto EPerThread = 1; - - using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>; - using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = - Sequence<1, E1PerBlock, 1, KPerBlock, 1>; - - constexpr auto ABlockTransferSrcScalarPerVector_E2 = E2; - constexpr auto ABlockTransferDstScalarPerVector_E2 = E2; - constexpr auto BThreadTransferSrcScalarPerVector_E2 = E2; - constexpr auto CThreadTransferDstScalarPerVector_K = InWeiVectorSize; -#endif - - const auto in_n_c0_hi_wi_c1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)); - const auto wei_k_c0_y_x_c1_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2)); - const auto add_n_k0_hox2_wox2_k1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1)); - const auto out_n_k0_ho_wo_k1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); - - constexpr auto conv_driver = - DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add< - BlockSize, - typename vector_type::type, - TAcc, - TOut, - E1, - E2, - K2, - KPerBlock, - HoPerBlock, - WoPerBlock, - E1PerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, - ABlockTransferSrcScalarPerVector_E2, - ABlockTransferDstScalarPerVector_E2, - BThreadTransferSrcScalarPerVector_E2, - CThreadTransferDstScalarPerVector_K, - activ_type>{}; - - std::cerr << "conv_bias_activ_resize_add_input_" - << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K - << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_addout_n" << N << "k" << K0 - << "h" << Ho * 2 << "w" << Wo * 2 << "k" << K1 << std::endl; - - for(int i = 0; i < 5; i++) - { - - const auto ave_time = - conv_driver.Run(wei_k_c0_y_x_c1_desc, - in_n_c0_hi_wi_c1_desc, - out_n_k0_ho_wo_k1_desc, - add_n_k0_hox2_wox2_k1_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - static_cast::type*>( - wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), - static_cast(bias_k0_k1_device_buf.GetDeviceBuffer()), - 
static_cast(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()), - nrepeat); - - { - float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data()); - - conv_driver.Run(wei_k_c0_y_x_c1_desc, - in_n_c0_hi_wi_c1_desc, - out_n_k0_ho_wo_k1_desc, - add_n_k0_hox2_wox2_k1_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - static_cast::type*>( - wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), - static_cast(bias_k0_k1_device_buf.GetDeviceBuffer()), - static_cast(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()), - 0); - - add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 79d31ba2467c1d59e3083e0f2225f69810fd93cc..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,309 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" -#include "debug.hpp" - -template -void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; - constexpr 
index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - in_n_hi_wi_c_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - I0, - I0, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto in_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - // clang-format off - constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 
// 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - //clang-format on - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; - - constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(out_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmABlockTransferSrcScalarPerVector_GemmM, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<1, 3, 7, 0, 2, 4, 5, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - false, // ABlockLdsExtraM - false // BBlockLdsExtraN - >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - out_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp 
b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index e3b6a6c8c29729295144ef83318a4f3faf257144..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,423 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - 
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: - // gemmk1 - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: 
gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - - // clang-format off - constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - // clang-format on - - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - float ave_time = 0; - - for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) - { - for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) - { - const auto descs = - transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( - out_n_ho_wo_k_desc, - wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - i_ytilde, - i_xtilde, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto in_gemmm_gemmn_grid_desc = descs[I2]; - - const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0); - - if(GemmK0 != 0) - { - ave_time += driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - 
TAcc, - TOut, - InMemoryDataOperationEnum::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy -#if 0 - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, -#else - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, -#endif - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - true, // CAccessOrderMRepeatNRepeat - false, // ABlockLdsExtraM - false // BBlockLdsExtraN - >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - } - } - } - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp deleted file mode 100644 index 9cc4052f7782ec3fdc1d17df7e6b83a1498402d1..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp +++ /dev/null @@ -1,389 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void 
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations&, - const InLeftPads&, - const InRightPads&, - Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t 
GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr 
index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0>{}, // 1+: gemmm - Sequence<0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0 - Sequence<0, 0, 0>{}, // 1-: gemmm - Sequence<0, 0, 0>{})); // 2-: gemmk1 - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0>{}, // 1+: gemmn - Sequence<0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0 - Sequence<0, 0, 0>{}, // 1-: Gemmn - Sequence<0, 0, 0>{})); // 2-: Gemmk1 - - // clang-format off - constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - // clang-format on - - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( - out_n_ho_wo_k_desc, - wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - conv_strides, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto in_gemmm_gemmn_grid_desc = descs[I2]; - - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy -#if 0 - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, -#else - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, -#endif - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - true, // CAccessOrderMRepeatNRepeat - false, // ABlockLdsExtraM - false // BBlockLdsExtraN - >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - 
const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 993630f3f8a035f3f990c9a132bf7f58cd83be4e..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,256 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp" -#include "driver_gemm_xdlops_v2r4.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - Tensor& wei_k_c_y_x, - const Tensor& out_n_k_ho_wo, - GridSizeType desired_grid_size, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>; - // using vector load 4, so config's wo*ho must be a multiple of 4 - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>; - 
using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto N = in_n_c_hi_wi_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_desc.GetLength(I1); - - const auto Ho = out_n_k_ho_wo_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_desc.GetLength(I2); - const auto X = wei_k_c_y_x_desc.GetLength(I3); - - const auto GemmM = K; - const auto GemmN = Y * X * C; - const auto GemmKTotal = N * Ho * Wo; - - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); - const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; - - std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN - << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad - << std::endl; - const auto descs = - transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}, - GemmKBatch, - GemmKPad); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB - Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM - Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemB - Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM - Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 1, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; - - const auto driver_gemm_xdlops = - driver_gemm_xdlops_v2r4, - Sequence<0, 2, 1, 3>, - 3, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1, 3>, - Sequence<0, 2, 1, 3>, - 3, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 0, 1, 2, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, - true, - true>; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - 0); - // copy result back to host - wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); -} diff --git 
a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index dfb612f690e5f01f3fb18301b1c12d498bfb6716..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,234 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - Tensor& wei_k_c_y_x, - const Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - // using vector load 4, so config's wo*ho must be a multiple of 4 - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - 
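All of these deleted drivers report performance the same way: the convolution FLOP count, either from calculate_convolution_flops(...) or spelled out as 2 * N * K * Ho * Wo * C * Y * X (one multiply and one add per MAC), is divided by 1e9 and then by the average kernel time in milliseconds, which yields TFlop/s directly since GFLOP per millisecond equals TFLOP per second. A minimal sketch of that conversion, with an illustrative helper name that does not exist in the library:

#include <cstddef>

// perf [TFlop/s] = (2*N*K*Ho*Wo*C*Y*X) / 1e9 / ave_time_ms
inline float conv_tflops(std::size_t N, std::size_t K, std::size_t Ho, std::size_t Wo,
                         std::size_t C, std::size_t Y, std::size_t X, float ave_time_ms)
{
    const std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
    return static_cast<float>(flop) / (std::size_t(1000) * 1000 * 1000) / ave_time_ms;
}

// Example arithmetic: N=128, K=256, Ho=Wo=14, C=1024, Y=X=1 gives ~1.32e10 FLOP,
// so an average time of 0.2 ms corresponds to roughly 65.8 TFlop/s.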
constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - // using vector load 4, so config's wo*ho must be a multiple of 4 - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1 - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 1, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float 
ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TIn, - TAcc, - TWei, - InMemoryDataOperationEnum::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(wei_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 0, 1, 2, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 06d0ea684f9824afa19b1e22e8b041be14a3755a..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,288 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r4.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& 
conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - GridSizeType desired_grid_size, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto N = in_n_hi_wi_c_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_desc.GetLength(I3); - - const 
auto Ho = out_n_ho_wo_k_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_desc.GetLength(I1); - const auto X = wei_k_y_x_c_desc.GetLength(I2); - - const auto GemmM = Y * X * C; - const auto GemmN = K; - const auto GemmKTotal = N * Ho * Wo; - - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); - const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; - - std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN - << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad - << std::endl; - - const auto descs = - transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}, - GemmKBatch, - GemmKPad); - - const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto 
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; - - constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< - BlockSize, - TIn, - TAcc, - TWei, - InMemoryDataOperationEnum::AtomicAdd, - decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), - decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), - decltype(wei_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmM, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, - true>; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - 
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - 0); - // copy result back to host - wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 5221ec582d2d12466e6449301a758b921afa27f4..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,276 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" -#include "debug.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - 
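The NHWC backward-weight drivers deleted here (and the one whose configuration begins below) all reduce over GemmKTotal = N * Ho * Wo: every output position of every image contributes one term to each weight-gradient entry, which is exactly why the problem flattens into a GEMM whose reduction dimension is N*Ho*Wo. A naive host-side statement of that computation is sketched below for orientation; it is not code from this repository, the names are illustrative, and it assumes the same zero-padding convention the deleted transforms use.

#include <algorithm>
#include <vector>

// Naive NHWC backward-weight convolution: wei[k][y][x][c] accumulates
// in[n][hi][wi][c] * out[n][ho][wo][k] over n, ho, wo, where
// hi = ho*stride_h + y*dilation_h - pad_h (and likewise for wi);
// out-of-range (padded) input positions contribute zero.
void naive_conv_bwd_weight_nhwc(const std::vector<float>& in,   // N*Hi*Wi*C
                                std::vector<float>& wei,        // K*Y*X*C
                                const std::vector<float>& out,  // N*Ho*Wo*K
                                int N, int Hi, int Wi, int C,
                                int K, int Y, int X, int Ho, int Wo,
                                int stride_h, int stride_w,
                                int dilation_h, int dilation_w,
                                int pad_h, int pad_w)
{
    std::fill(wei.begin(), wei.end(), 0.f);
    for(int k = 0; k < K; ++k)
        for(int y = 0; y < Y; ++y)
            for(int x = 0; x < X; ++x)
                for(int c = 0; c < C; ++c)
                {
                    float acc = 0.f;
                    for(int n = 0; n < N; ++n)
                        for(int ho = 0; ho < Ho; ++ho)
                            for(int wo = 0; wo < Wo; ++wo)
                            {
                                const int hi = ho * stride_h + y * dilation_h - pad_h;
                                const int wi = wo * stride_w + x * dilation_w - pad_w;
                                if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
                                    continue; // padded region contributes zero
                                acc += in[((n * Hi + hi) * Wi + wi) * C + c] *
                                       out[((n * Ho + ho) * Wo + wo) * K + k];
                            }
                    wei[((k * Y + y) * X + x) * C + c] = acc;
                }
}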
- constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; - -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 - - constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; - - constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TIn, - TAcc, - TWei, - InMemoryDataOperationEnum::Set, - decltype(in_gemmk0_gemmm_gemmk1_grid_desc), - decltype(out_gemmk0_gemmn_gemmk1_grid_desc), - decltype(wei_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - GemmABlockTransferSrcScalarPerVector_GemmM, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, - true>(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - in_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - in_gemmk0_gemmm_gemmk1_grid_step_hacks, - out_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = 
wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 1bdad6e97b307667eac37cabef04bb10ea08431f..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,456 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r4.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - GridSizeType desired_grid_size, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t 
GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, 
K1] = [128, 128, 4, 8], C 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 64; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 64; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t 
GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto N = in_n_hi_wi_c_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_desc.GetLength(I3); - - const auto Ho = out_n_ho_wo_k_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_desc.GetLength(I1); - const auto X = wei_k_y_x_c_desc.GetLength(I2); - - const auto GemmM = K; - const auto GemmN = Y * X * C; - const auto GemmKTotal = N * Ho * Wo; - - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); - const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; - - std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN - << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad - << std::endl; - - const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}, - GemmKBatch, - GemmKPad); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto 
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; - - const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< - BlockSize, - TIn, - TAcc, - TWei, - InMemoryDataOperationEnum::AtomicAdd, - decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), - decltype(wei_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmM, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 3, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, - true>; - - // timing - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - float perf = static_cast((std::size_t(2) 
* N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // verification - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - 0); - // copy result back to host - wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index a9df58bedda4ebb0ff0b84c9b5548ef6e174e0f8..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,201 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "driver_gemm_dlops_v1r2.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 1 - // cdata = 64, BlockSize = 256, 128x128x8 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t 
GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); - - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - const auto wei_gemmk_gemmm_grid_desc = descs[I0]; - const auto in_gemmk_gemmn_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_dlops_v1r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum::Set, - decltype(wei_gemmk_gemmm_grid_desc), - decltype(in_gemmk_gemmn_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlockM1, - GemmNPerBlockN1, - GemmKPerBlock, - GemmM1PerThreadM111, - GemmN1PerThreadN111, - GemmKPerThread, - GemmM11N11ThreadClusterM1100, - GemmM11N11ThreadClusterN1100, - GemmM11N11ThreadClusterM1101, - GemmM11N11ThreadClusterN1101, - GemmABlockTransferThreadSliceLengths_K_M0_M1, - GemmABlockTransferThreadClusterLengths_K_M0_M1, - Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder - Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder - 0, // ABlockTransferSrcVectorDim - 
GemmABlockTransferSrcScalarPerVector_K, - GemmABlockTransferDstScalarPerVector_M1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_K_N0_N1, - GemmBBlockTransferThreadClusterLengths_K_N0_N1, - Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder - Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - GemmBBlockTransferSrcScalarPerVector_N1, - GemmBBlockTransferDstScalarPerVector_N1, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - GemmCThreadTransferDstScalarPerVector_N11, - decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), - decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>( - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - wei_gemmk_gemmm_grid_desc, - in_gemmk_gemmn_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk_gemmm0_gemmn1_grid_step_hacks, - in_gemmk_gemmn0_gemmn1_grid_step_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, - wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks, - in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 843df27a88a774d44ef8917214cb59d387a9ba21..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,273 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_dlops_v1r3.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem 
out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 128, 8, 1] for fp32 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmK1 = 1; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; - using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; - - using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; - - using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; - using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>; - - using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; - - using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; - using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 8, 2] for fp16 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmK1 = 2; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; - using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; - - using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; - using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; - - using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; - using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>; - - using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; - using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; - - using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; - using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 8, 4] for i8 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmK1 = 4; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - using 
GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; - using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; - - using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; - using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; - - using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; - using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>; - - using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; - using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; - - using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; - using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10 - Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11 - Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0 - Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10 - Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10 - Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11 - Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0 - Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10 - Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11 - - constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; - - constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time 
= driver_gemm_dlops_v1r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum::Set, - decltype(in_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlockM1, - GemmNPerBlockN1, - GemmKPerBlock, - GemmM1PerThreadM111, - GemmN1PerThreadN111, - GemmKPerThread, - GemmM11N11ThreadClusterM110Xs, - GemmM11N11ThreadClusterN110Xs, - GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1, - GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1, - Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder - Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder - GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, - Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder - GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, - GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1, - GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1, - Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder - Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder - GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, - Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder - GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, - Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - GemmCThreadTransferDstScalarPerVector_N11, - decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), - decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>( - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, - in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index e4cf4dd25cd88022404311b976aa8f268d18b632..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,228 +0,0 @@ -#include 
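// For this NCHW/KCYX/NKHW driver the convolution is lowered to a single GEMM in which
// the weight tensor supplies the A matrix and the im2col view of the input supplies the
// B matrix (descs[I0] and descs[I1] below), so GemmM = K, GemmN = N * Ho * Wo and
// GemmK = C * Y * X, split as GemmK0 * GemmK1. A small sketch of that mapping; the
// GemmShape struct and conv_fwd_nchw_gemm_shape helper are illustrative names only,
// not part of this codebase (requires <cstddef>):
struct GemmShape
{
    std::size_t M, N, K;
};

inline GemmShape conv_fwd_nchw_gemm_shape(std::size_t N, std::size_t K, std::size_t Ho, std::size_t Wo,
                                          std::size_t C, std::size_t Y, std::size_t X)
{
    // M = K (output channels), N = N * Ho * Wo (output pixels), K = C * Y * X (reduction)
    return {K, N * Ho * Wo, C * Y * X};
}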
-#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr 
index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<1, 0, 2>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 0, 1, 2, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 18e712fb47c0ab5944a3adeaeb629b9d44398a0a..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,600 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -#if 0 -__host__ __device__ static constexpr auto -MakePaddedGridDescriptors(const AGridDesc_K0Raw_MRaw_K1& a_grid_desc_k0raw_mraw_k1, - const BGridDesc_K0Raw_NRaw_K1& b_grid_desc_k0raw_nraw_k1, - const CGridDesc_MRaw_NRaw& c_grid_desc_mraw_nraw) -{ - const auto K0Raw = a_grid_desc_k0raw_mraw_k1.GetLength(I0); - const auto K1 = a_grid_desc_k0raw_mraw_k1.GetLength(I2); - const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0); - const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1); - - const auto K0Pad = math::integer_least_multiple(K0Raw, K0PerBlock) - K0Raw; - const auto MPad = math::integer_least_multiple(MRaw, MPerBlock) - MRaw; - const auto NPad = math::integer_least_multiple(NRaw, NPerBlock) - NRaw; - - // A - const auto a_grid_desc_k0_m_k1 = [&]() { - if constexpr(DoPad_K0 && DoPad_M) - { - return transform_tensor_descriptor( - a_grid_desc_k0_m_k1, - make_tuple(make_right_pad_transform(K0Raw, K0Pad), - make_right_pad_transform(MRaw, MPad), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - } - else if constexpr(DoPad_K0 && !DoPad_M) - { - return transform_tensor_descriptor( - 
a_grid_desc_k0_m_k1, - make_tuple(make_right_pad_transform(K0Raw, K0Pad), - make_pass_through_transform(MRaw), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - } - else if constexpr(!DoPad_K0 && DoPad_M) - { - return transform_tensor_descriptor( - a_grid_desc_k0_m_k1, - make_tuple(make_pass_through_transform(K0Raw), - make_right_pad_transform(MRaw, MPad), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - } - else - { - return a_grid_desc_k0raw_mraw_k1; - } - }(); - - // B - const auto b_grid_desc_k0_n_k1 = [&]() { - if constexpr(DoPad_K0 && DoPad_N) - { - return transform_tensor_descriptor( - b_grid_desc_k0_n_k1, - make_tuple(make_right_pad_transform(K0Raw, K0Pad), - make_right_pad_transform(NRaw, NPad), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - } - else if constexpr(DoPad_K0 && !DoPad_N) - { - return transform_tensor_descriptor( - b_grid_desc_k0_n_k1, - make_tuple(make_right_pad_transform(K0Raw, K0Pad), - make_pass_through_transform(NRaw), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - } - else if constexpr(!DoPad_K0 && DoPad_N) - { - return transform_tensor_descriptor( - b_grid_desc_k0_n_k1, - make_tuple(make_pass_through_transform(K0Raw), - make_right_pad_transform(NRaw, NPad), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - } - else - { - return b_grid_desc_k0raw_nraw_k1; - } - }(); - - // C - const auto c_grid_desc_m_n = [&]() { - if constexpr(DoPad_M && DoPad_N) - { - return transform_tensor_descriptor(c_grid_desc_m_n, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(DoPad_M && !DoPad_N) - { - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(!DoPad_M && DoPad_N) - { - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - reutnr c_grid_desc_m_n; - } - }(); -} -#endif - -template -void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * 
in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - 
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; 
- - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - -#if 0 // debug - const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - - // HACK: hacks that control index calculation when iterating over A matrix - constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; -#else - const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0]; - - const auto GemmK0 = 
in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0); - const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1); - const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw; - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmK1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - // HACK: hacks that control index calculation when iterating over A matrix - constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; -#endif - - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - - const auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - -#if 0 - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 -#else - const auto out_gemmmraw_gemmn_grid_desc = descs[I2]; - - const auto GemmN = out_gemmmraw_gemmn_grid_desc.GetLength(I1); - - const auto out_gemmm_gemmn_grid_desc = - transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 -#endif - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum::Set, - decltype(in_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - in_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - 
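// A note on the performance figure computed just above: the FLOP count of a forward
// convolution is 2 * N * K * Ho * Wo * C * Y * X (one multiply and one add per
// multiply-accumulate), and dividing by 1e9 and by the average kernel time in
// milliseconds gives TFLOP/s, since 1e9 FLOP per millisecond equals 1e12 FLOP per
// second. A minimal standalone sketch of the same computation; the helper name
// conv_fwd_tflops and the parameter ave_time_ms are illustrative only and do not
// exist in this codebase (requires <cstddef>):
inline float conv_fwd_tflops(std::size_t N, std::size_t K, std::size_t Ho, std::size_t Wo,
                             std::size_t C, std::size_t Y, std::size_t X, float ave_time_ms)
{
    // 2 FLOP per multiply-accumulate of the implicit GEMM
    const std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
    // 1e9 FLOP per millisecond == 1 TFLOP/s
    return static_cast<float>(flop) / (std::size_t(1000) * 1000 * 1000) / ave_time_ms;
}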
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp deleted file mode 100644 index af4676f2a242f9733fa05b8adf82558573d51607..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp +++ /dev/null @@ -1,196 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" - -template -void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( - const InLengths& in_n_c0_hi_wi_c1_lengths, - const WeiLengths& wei_k_c0_y_x_c1_lengths, - const OutLengths& out_n_k0_ho_wo_k1_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c0_hi_wi_c1, - const Tensor& wei_k_c0_y_x_c1, - const Tensor& bias_k0_k1, - Tensor& out_n_k0_ho_wo_k1, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - const auto N = out_n_k0_ho_wo_k1_lengths[I0]; - const auto K0 = out_n_k0_ho_wo_k1_lengths[I1]; - const auto Ho = out_n_k0_ho_wo_k1_lengths[I2]; - const auto Wo = out_n_k0_ho_wo_k1_lengths[I3]; - const auto K1 = out_n_k0_ho_wo_k1_lengths[I4]; - - const auto C0 = in_n_c0_hi_wi_c1_lengths[I1]; - const auto Hi = in_n_c0_hi_wi_c1_lengths[I2]; - const auto Wi = in_n_c0_hi_wi_c1_lengths[I3]; - const auto C1 = in_n_c0_hi_wi_c1_lengths[I4]; - - const auto K = wei_k_c0_y_x_c1_lengths[I0]; - const auto Y = wei_k_c0_y_x_c1_lengths[I2]; - const auto X = wei_k_c0_y_x_c1_lengths[I3]; - - DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * - in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); - DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); - DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace()); - DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * - out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); - in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); - wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); - bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data()); - - constexpr index_t InWeiVectorSize = 8; - - if(C1 % InWeiVectorSize != 0) - { - throw std::runtime_error("wrong! 
C1 cannot be divided by InWeiVectorSize"); - } - -#if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 32; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 64; - - constexpr index_t E1 = C0 * 9; - constexpr index_t E2 = 1; - constexpr index_t E1PerBlock = C0; - - constexpr index_t KPerThread = 16; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = 1; - - using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>; - using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; - constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; - - constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; - - constexpr index_t CThreadTransferDstScalarPerVector_K = K1; -#elif 1 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 8; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 32; - - constexpr index_t E1 = 2 * 9; - constexpr index_t E2 = 1; - constexpr index_t K2 = 2; - constexpr index_t E1PerBlock = 2; - - constexpr index_t KPerThread = KPerBlock; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = 1; - - using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>; - using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = - Sequence<1, E1PerBlock, 1, KPerBlock, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; - constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; - constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; - constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize; -#endif - - if(KPerThread % InWeiVectorSize != 0) - { - throw std::runtime_error("wrong! 
C1 cannot be divided by InWeiVectorSize"); - } - - const auto in_n_c0_hi_wi_c1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)); - const auto wei_k_c0_y_x_c1_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2)); - const auto out_n_k0_ho_wo_k1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); - - constexpr auto conv_driver = - DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad< - BlockSize, - typename vector_type::type, - TAcc, - TOut, - E1, - E2, - K2, - KPerBlock, - HoPerBlock, - WoPerBlock, - E1PerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, - ABlockTransferSrcScalarPerVector_E2, - ABlockTransferDstScalarPerVector_E2, - BThreadTransferSrcScalarPerVector_E2, - CThreadTransferDstScalarPerVector_K, - activ_type>{}; - - std::cerr << "conv_bias_activ_input_" - << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K - << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0 - << "h" << Ho << "w" << Wo << "k" << K1 << std::endl; - - for(int i = 0; i < 5; i++) - { - - const auto ave_time = - conv_driver.Run(wei_k_c0_y_x_c1_desc, - in_n_c0_hi_wi_c1_desc, - out_n_k0_ho_wo_k1_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - static_cast::type*>( - wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), - static_cast(bias_k0_k1_device_buf.GetDeviceBuffer()), - static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()), - nrepeat); - - { - float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 31925f0511cd60155ec7dcc0a9989399f1a0d67c..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,241 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" -#include "driver_contraction_dlops_v1r2.hpp" - -template -void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * 
wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 1 - // [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GN0 = 4; - constexpr index_t GK1 = 1; - - constexpr index_t GM1PerBlockGM11 = 128; - constexpr index_t GN1PerBlockGN11 = 32; - constexpr index_t GK0PerBlock = 8; - - constexpr index_t BM1PerThreadBM11 = 4; - constexpr index_t BN1PerThreadBN11 = 4; - constexpr index_t BK0PerThread = 1; - - using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; - using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; - - using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; - - using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>; - - using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>; - using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; - - using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - - constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; -#elif 1 - // [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GN0 = 4; - constexpr index_t GK1 = 2; - - constexpr index_t GM1PerBlockGM11 = 128; - constexpr index_t GN1PerBlockGN11 = 32; - constexpr index_t GK0PerBlock = 8; - - constexpr index_t BM1PerThreadBM11 = 4; - constexpr index_t BN1PerThreadBN11 = 4; - constexpr index_t BK0PerThread = 1; - - using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; - using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; - - using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; - using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; - - using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>; - - using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>; - using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; - - using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>; - - constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; -#endif - - const auto descs = - transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x, - in_desc_n_c_hi_wi, - out_desc_n_k_ho_wo, - conv_strides, - conv_dilations, - 
in_left_pads, - in_right_pads, - Number{}, - Number{}); - - const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; - const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; - const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 - - constexpr auto in_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 - - constexpr auto out_grid_step_hacks = make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1 - - constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}; - - constexpr auto in_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_contraction_dlops_v1r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum::Set, - decltype(wei_grid_desc_gk0_gm0_gm1_gk1), - decltype(in_grid_desc_gk0_gn0_gn1_gk1), - decltype(out_grid_desc_gm0_gm1_gn0_gn1), - GM1PerBlockGM11, - GN1PerBlockGN11, - GK0PerBlock, - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - 
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder - Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder - Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder - Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - CThreadTransferDstScalarPerVector_BN1, - decltype(wei_grid_step_hacks), - decltype(in_grid_step_hacks), - decltype(out_grid_step_hacks), - decltype(wei_grid_move_slice_window_step_hacks), - decltype(in_grid_move_slice_window_step_hacks)>( - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - wei_grid_desc_gk0_gm0_gm1_gk1, - in_grid_desc_gk0_gn0_gn1_gk1, - out_grid_desc_gm0_gm1_gn0_gn1, - wei_grid_step_hacks, - in_grid_step_hacks, - out_grid_step_hacks, - wei_grid_move_slice_window_step_hacks, - in_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp deleted file mode 100644 index 2cb2e1091528e52b681a8dbfa6ee6ae87c53f7a2..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp +++ /dev/null @@ -1,212 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" - -template -void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( - const InLengths& in_n_c0_hi_wi_c1_lengths, - const WeiLengths& wei_k_c0_y_x_c1_lengths, - const MaxLengths& max_n_k0_hx_wx_k1_lengths, - const OutLengths& out_n_k0_ho_wo_k1_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c0_hi_wi_c1, - const Tensor& wei_k_c0_y_x_c1, - const Tensor& bias_k0_k1, - Tensor& out_n_k0_ho_wo_k1, - Tensor& max_n_k0_hx_wx_k1, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - 
constexpr auto I4 = Number<4>{}; - - const auto N = out_n_k0_ho_wo_k1_lengths[I0]; - const auto K0 = out_n_k0_ho_wo_k1_lengths[I1]; - const auto Ho = out_n_k0_ho_wo_k1_lengths[I2]; - const auto Wo = out_n_k0_ho_wo_k1_lengths[I3]; - const auto K1 = out_n_k0_ho_wo_k1_lengths[I4]; - - const auto C0 = in_n_c0_hi_wi_c1_lengths[I1]; - const auto Hi = in_n_c0_hi_wi_c1_lengths[I2]; - const auto Wi = in_n_c0_hi_wi_c1_lengths[I3]; - const auto C1 = in_n_c0_hi_wi_c1_lengths[I4]; - - const auto K = wei_k_c0_y_x_c1_lengths[I0]; - const auto Y = wei_k_c0_y_x_c1_lengths[I2]; - const auto X = wei_k_c0_y_x_c1_lengths[I3]; - - const auto Hx = max_n_k0_hx_wx_k1_lengths[I2]; - const auto Wx = max_n_k0_hx_wx_k1_lengths[I3]; - - DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * - in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); - DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); - DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace()); - DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * - out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); - DeviceMem max_n_k0_hx_wx_k1_device_buf(sizeof(TOut) * - max_n_k0_hx_wx_k1.mDesc.GetElementSpace()); - - in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); - wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); - bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data()); - max_n_k0_hx_wx_k1_device_buf.ToDevice(max_n_k0_hx_wx_k1.mData.data()); - - constexpr index_t InWeiVectorSize = 8; - - if(C1 % InWeiVectorSize != 0) - { - throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize"); - } - -#if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 32; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 64; - - constexpr index_t E1 = C0 * 9; - constexpr index_t E2 = 1; - constexpr index_t E1PerBlock = C0; - - constexpr index_t KPerThread = 16; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = 1; - - using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>; - using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; - constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; - - constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; - - constexpr index_t CThreadTransferDstScalarPerVector_K = K1; -#elif 1 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 8; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 32; - - constexpr index_t E1 = 2 * 9; - constexpr index_t E2 = 1; - constexpr index_t K2 = 2; - constexpr index_t E1PerBlock = 2; - - constexpr index_t KPerThread = KPerBlock; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = 1; - - using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>; - using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = - Sequence<1, E1PerBlock, 1, KPerBlock, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; - constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; - constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; - constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize; -#endif - - if(KPerThread % InWeiVectorSize != 0) - { - throw std::runtime_error("wrong! 
C1 cannot be divided by InWeiVectorSize"); - } - - const auto in_n_c0_hi_wi_c1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)); - const auto wei_k_c0_y_x_c1_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2)); - const auto max_n_k0_hx_wx_k1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1)); - const auto out_n_k0_ho_wo_k1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); - - constexpr auto conv_driver = - DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool< - BlockSize, - typename vector_type::type, - TAcc, - TOut, - E1, - E2, - K2, - KPerBlock, - HoPerBlock, - WoPerBlock, - E1PerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, - ABlockTransferSrcScalarPerVector_E2, - ABlockTransferDstScalarPerVector_E2, - BThreadTransferSrcScalarPerVector_E2, - CThreadTransferDstScalarPerVector_K, - activ_type>{}; - - std::cerr << "conv_bias_activ_maxpool_input_" - << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K - << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0 - << "h" << Ho << "w" << Wo << "k" << K1 << "_maxpoolout_n" << N << "k" << K0 << "h" - << Ho / 2 << "w" << Wo / 2 << "k" << K1 << std::endl; - - for(int i = 0; i < 5; i++) - { - - const auto ave_time = - conv_driver.Run(wei_k_c0_y_x_c1_desc, - in_n_c0_hi_wi_c1_desc, - out_n_k0_ho_wo_k1_desc, - max_n_k0_hx_wx_k1_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - static_cast::type*>( - wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), - static_cast(bias_k0_k1_device_buf.GetDeviceBuffer()), - static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()), - static_cast(max_n_k0_hx_wx_k1_device_buf.GetDeviceBuffer()), - nrepeat); - - { - float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); - max_n_k0_hx_wx_k1_device_buf.FromDevice(max_n_k0_hx_wx_k1.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_mn.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_mn.hpp deleted file mode 100644 index f54ff181dd94d10bd3a3aad975f39ac60d48cedc..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_mn.hpp +++ /dev/null @@ -1,463 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_km_kn_mn(const Tensor& a_k_m, - const Tensor& b_k_n, - Tensor& c_m_n, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); - - a_k_m_device_buf.ToDevice(a_k_m.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 
128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using 
BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - 
- constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t 
CThreadTransferDstScalarPerVector = 1; -#endif - - const auto K = a_k_m.mDesc.GetLengths()[0]; - const auto M = a_k_m.mDesc.GetLengths()[1]; - const auto N = b_k_n.mDesc.GetLengths()[1]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], - a_k_m.mDesc.GetStrides()[1], - a_k_m.mDesc.GetStrides()[0])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], - b_k_n.mDesc.GetStrides()[1], - b_k_n.mDesc.GetStrides()[0])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<0, 2, 1>, - 1, - ABlockTransferSrcScalarPerVector_M, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - BBlockTransferSrcScalarPerVector_N, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, - 7, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - 
static_cast(c_m_n_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_m_n_device_buf.FromDevice(c_m_n.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_nm.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_nm.hpp deleted file mode 100644 index eb78ba96d8b3afec2f86dca6848c4b4c1c7bae98..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_nm.hpp +++ /dev/null @@ -1,263 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_km_kn_nm(const Tensor& a_k_m, - const Tensor& b_k_n, - Tensor& c_n_m, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); - - a_k_m_device_buf.ToDevice(a_k_m.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_n_m_device_buf.ToDevice(c_n_m.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = 
Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#endif - - const auto K = a_k_m.mDesc.GetLengths()[0]; - const auto M = a_k_m.mDesc.GetLengths()[1]; - const auto N = b_k_n.mDesc.GetLengths()[1]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], - a_k_m.mDesc.GetStrides()[1], - a_k_m.mDesc.GetStrides()[0])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], - b_k_n.mDesc.GetStrides()[1], - b_k_n.mDesc.GetStrides()[0])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto 
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<0, 2, 1>, - 1, - ABlockTransferSrcScalarPerVector_M, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - BBlockTransferSrcScalarPerVector_N, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_n_m_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_n_m_device_buf.FromDevice(c_n_m.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_mn.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_mn.hpp deleted file mode 100644 index dbd318ce4dc3f2541120a39f6c4bfcbe0e887786..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_mn.hpp +++ /dev/null @@ -1,463 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_km_nk_mn(const Tensor& a_k_m, - const Tensor& b_n_k, - Tensor& c_m_n, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); - DeviceMem b_n_k_device_buf(sizeof(ABType) * 
b_n_k.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); - - a_k_m_device_buf.ToDevice(a_k_m.mData.data()); - b_n_k_device_buf.ToDevice(b_n_k.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - 
- using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = 
[128, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - 
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#endif - - const auto K = a_k_m.mDesc.GetLengths()[0]; - const auto M = a_k_m.mDesc.GetLengths()[1]; - const auto N = b_n_k.mDesc.GetLengths()[0]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], - a_k_m.mDesc.GetStrides()[1], - a_k_m.mDesc.GetStrides()[0])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], - b_n_k.mDesc.GetStrides()[0], - b_n_k.mDesc.GetStrides()[1])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<0, 2, 1>, - 1, - ABlockTransferSrcScalarPerVector_M, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - BBlockTransferSrcScalarPerVector_K1, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, - 7, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - 
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), - static_cast(b_n_k_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_m_n_device_buf.FromDevice(c_m_n.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_nm.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_nm.hpp deleted file mode 100644 index 5b819fd1af44a1ffebef83a9b30d917517805c42..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_nm.hpp +++ /dev/null @@ -1,263 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_km_nk_nm(const Tensor& a_k_m, - const Tensor& b_n_k, - Tensor& c_n_m, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); - DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); - DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); - - a_k_m_device_buf.ToDevice(a_k_m.mData.data()); - b_n_k_device_buf.ToDevice(b_n_k.mData.data()); - c_n_m_device_buf.ToDevice(c_n_m.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using 
ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#endif - - const auto K = a_k_m.mDesc.GetLengths()[0]; - const auto M = a_k_m.mDesc.GetLengths()[1]; - const auto N = b_n_k.mDesc.GetLengths()[0]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], - a_k_m.mDesc.GetStrides()[1], - a_k_m.mDesc.GetStrides()[0])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], - b_n_k.mDesc.GetStrides()[0], - b_n_k.mDesc.GetStrides()[1])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - 
Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<0, 2, 1>, - 1, - ABlockTransferSrcScalarPerVector_M, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - BBlockTransferSrcScalarPerVector_K1, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), - static_cast(b_n_k_device_buf.GetDeviceBuffer()), - static_cast(c_n_m_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_n_m_device_buf.FromDevice(c_n_m.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_mn.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_mn.hpp deleted file mode 100644 index 4b041777c3ee00adba32dcc14451cc6d510534e5..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_mn.hpp +++ /dev/null @@ -1,463 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_mk_kn_mn(const Tensor& 
a_m_k, - const Tensor& b_k_n, - Tensor& c_m_n, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr 
index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using 
BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat 
= 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#endif - - const auto K = a_m_k.mDesc.GetLengths()[1]; - const auto M = a_m_k.mDesc.GetLengths()[0]; - const auto N = b_k_n.mDesc.GetLengths()[1]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], - a_m_k.mDesc.GetStrides()[0], - a_m_k.mDesc.GetStrides()[1])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], - b_k_n.mDesc.GetStrides()[1], - b_k_n.mDesc.GetStrides()[0])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<1, 0, 2>, - 2, - ABlockTransferSrcScalarPerVector_K1, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - BBlockTransferSrcScalarPerVector_N, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src 
coordinate after threadwise copy - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, - 7, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_m_n_device_buf.FromDevice(c_m_n.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_nm.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_nm.hpp deleted file mode 100644 index c848cd793615656ebd6a749739bd83abd4a5ec36..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_nm.hpp +++ /dev/null @@ -1,291 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_mk_kn_nm(const Tensor& a_m_k, - const Tensor& b_k_n, - Tensor& c_n_m, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_n_m_device_buf.ToDevice(c_n_m.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - 
constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - 
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#endif - - const auto K = a_m_k.mDesc.GetLengths()[1]; - const auto M = a_m_k.mDesc.GetLengths()[0]; - const auto N = b_k_n.mDesc.GetLengths()[1]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], - a_m_k.mDesc.GetStrides()[0], - a_m_k.mDesc.GetStrides()[1])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], - b_k_n.mDesc.GetStrides()[1], - b_k_n.mDesc.GetStrides()[0])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<1, 0, 2>, - 2, - ABlockTransferSrcScalarPerVector_K1, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - BBlockTransferSrcScalarPerVector_N, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - 
static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_n_m_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_n_m_device_buf.FromDevice(c_n_m.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_mn.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_mn.hpp deleted file mode 100644 index 557624026d5e05531b2b27fe3604095e27d48677..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_mn.hpp +++ /dev/null @@ -1,564 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_mk_nk_mn(const Tensor& a_m_k, - const Tensor& b_n_k, - Tensor& c_m_n, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_n_k_device_buf.ToDevice(b_n_k.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t 
BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - 
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16 - constexpr index_t 
BlockSize = 128; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#endif - - const auto K = a_m_k.mDesc.GetLengths()[1]; - const auto M = a_m_k.mDesc.GetLengths()[0]; - const auto N = b_n_k.mDesc.GetLengths()[0]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - -#if 1 - // non-padded GEMM - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], - a_m_k.mDesc.GetStrides()[0], - a_m_k.mDesc.GetStrides()[1])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], - b_n_k.mDesc.GetStrides()[0], - b_n_k.mDesc.GetStrides()[1])); - - const auto c_m_n_grid_desc = 
make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; -#else - // padded GEMM - const auto a_k0_m_k1_grid_desc_tmp = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], - a_m_k.mDesc.GetStrides()[0], - a_m_k.mDesc.GetStrides()[1])); - - const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M; - - const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp, - make_tuple(make_pass_through_transform(K0), - make_right_pad_transform(M, MRightPad), - make_pass_through_transform(K1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], - b_n_k.mDesc.GetStrides()[0], - b_n_k.mDesc.GetStrides()[1])); - - const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); - - const auto c_m_n_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc_tmp, - make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0, 0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 
1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; -#endif - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<1, 0, 2>, - 2, - ABlockTransferSrcScalarPerVector_K1, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - BBlockTransferSrcScalarPerVector_K1, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, - 7, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_n_k_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_m_n_device_buf.FromDevice(c_m_n.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_nm.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_nm.hpp deleted file mode 100644 index 06d8ed29404c52f13aaea53cb1801f941b0fe0ab..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_nm.hpp +++ /dev/null @@ -1,347 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include 
"host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_mk_nk_nm(const Tensor& a_m_k, - const Tensor& b_n_k, - Tensor& c_n_m, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); - DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_n_k_device_buf.ToDevice(b_n_k.mData.data()); - c_n_m_device_buf.ToDevice(c_n_m.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 
0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t 
ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#endif - - const auto K = a_m_k.mDesc.GetLengths()[1]; - const auto M = a_m_k.mDesc.GetLengths()[0]; - const auto N = b_n_k.mDesc.GetLengths()[0]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], - a_m_k.mDesc.GetStrides()[0], - a_m_k.mDesc.GetStrides()[1])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], - b_n_k.mDesc.GetStrides()[0], - b_n_k.mDesc.GetStrides()[1])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<1, 0, 2>, - 2, - ABlockTransferSrcScalarPerVector_K1, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - BBlockTransferSrcScalarPerVector_K1, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - 
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_n_k_device_buf.GetDeviceBuffer()), - static_cast(c_n_m_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_n_m_device_buf.FromDevice(c_n_m.mData.data()); -} diff --git a/library/include/ck/library/obselete_driver_offline/driver_contraction_dlops_v1r2.hpp b/library/include/ck/library/obselete_driver_offline/driver_contraction_dlops_v1r2.hpp deleted file mode 100644 index 000098f4fca0bc2c0e37ab156a277ee589859b71..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/driver_contraction_dlops_v1r2.hpp +++ /dev/null @@ -1,286 +0,0 @@ -#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP -#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_contraction_dlops_v1r2.hpp" - -template -__host__ float -driver_contraction_dlops_v1r2(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, - const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseContraction = - GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - CGlobalMemoryDataOperation, - AGridDesc_GK0_GM0_GM1_GK1, - BGridDesc_GK0_GN0_GN1_GK1, - CGridDesc_GM0_GM1_GN0_GN1, - GM1PerBlockGM11, - GN1PerBlockGN11, - GK0PerBlock, - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferSrcVectorTensorContiguousDimOrder, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferSrcVectorTensorContiguousDimOrder, - CThreadTransferSrcDstAccessOrder, - 
CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks>; - - const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); - - if(!GridwiseContraction::CheckValidity( - a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1)) - { - throw std::runtime_error("wrong! " - "GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_" - "GM0_GM1_GN0_GN1 has invalid setting"); - } - - const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = - GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1); - const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = - GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1); - - using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1); - using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1); - - // c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 - const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = - GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( - c_grid_desc_gm0_gm1_gn0_gn1); - - using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1); - - // c_grid_block_cluster_blockid_to_gm10_gn10 - const auto c_grid_block_cluster_blockid_to_gm10_gn10 = - GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( - c_grid_desc_gm0_gm1_gn0_gn1); - - using CGridBlockCluster_BlockId_To_GM10_GN10 = - decltype(c_grid_block_cluster_blockid_to_gm10_gn10); - - const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1); - - const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0); - - const bool has_double_tail_k_block_loop = - GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0); - - { - std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{" - << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", " - << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", " - << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", " - << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", " - << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl; - - std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{" - << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", " - << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", " - << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", " - << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", " - << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl; - - std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ " - << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", " - << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", " - << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", " - << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", " - << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", " - << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl; - } - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_contraction_dlops_v1r2< - GridwiseContraction, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, 
- a_grid_desc_gk0_gm0_gm10_gm11_gk1, - b_grid_desc_gk0_gn0_gn10_gn11_gk1, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_block_cluster_blockid_to_gm10_gn10); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = kernel_contraction_dlops_v1r2< - GridwiseContraction, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_grid_desc_gk0_gm0_gm10_gm11_gk1, - b_grid_desc_gk0_gn0_gn10_gn11_gk1, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_block_cluster_blockid_to_gm10_gn10); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_contraction_dlops_v1r2< - GridwiseContraction, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_grid_desc_gk0_gm0_gm10_gm11_gk1, - b_grid_desc_gk0_gn0_gn10_gn11_gk1, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_block_cluster_blockid_to_gm10_gn10); - } - else - { - const auto kernel = kernel_contraction_dlops_v1r2< - GridwiseContraction, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_grid_desc_gk0_gm0_gm10_gm11_gk1, - b_grid_desc_gk0_gn0_gn10_gn11_gk1, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_block_cluster_blockid_to_gm10_gn10); - } - - return ave_time; -} -#endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp deleted file mode 100644 index ec16a97f6f68ad3112f8ef1adcdf13ed3c51c70b..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp +++ /dev/null @@ -1,429 +0,0 @@ -#ifndef DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP -#define DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v3.hpp" - -template -struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add -{ - template - __host__ float Run(const ck::TensorDescriptor& wei_k_c0_y_x_c1_global_desc, - const ck::TensorDescriptor& in_n_c0_hi_wi_c1_global_desc, - const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, - const ck::TensorDescriptor& add_n_k0_hox2_wox2_k1_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_d_grid, - const int nrepeat) const - { - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 
= Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0); - const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1); - const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2); - const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3); - // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4); - - const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); - const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); - const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); - const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); - - const auto Hox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I2); - const auto Wox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I3); - - const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0); - const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2); - const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; - const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; - - const auto OutRightPadH = Hop - Ho; - const auto OutRightPadW = Wop - Wo; - - const auto OutRightPadHx = OutRightPadH * 2; - const auto OutRightPadWx = OutRightPadW * 2; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; - const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; - - const auto E = C0 * Y * X; - - constexpr auto E1 = Number{}; - constexpr auto E2 = Number{}; - constexpr auto K2 = Number{}; - - const auto E0 = E / E1; - - // weight tensor - const auto a_e_k_e2_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)), - make_tuple(make_pass_through_transform(K), - make_pass_through_transform(C0 * Y * X), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{})); - - const auto a_e0_e1_k_e2_grid_desc = - transform_tensor_descriptor(a_e_k_e2_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(E0, E1)), - make_pass_through_transform(K), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); - - // input tensor - const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)), - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C0), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor( - in_n_c0_hip_wip_e2_global_desc, - make_tuple( - make_pass_through_transform(N), - make_pass_through_transform(C0), - make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wop), 
make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{})); - - const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( - in_n_c0_y_ho_x_wo_e2_global_desc, - make_tuple(make_merge_transform(make_tuple(C0, Y, X)), - make_pass_through_transform(N), - make_pass_through_transform(Hop), - make_pass_through_transform(Wop), - make_pass_through_transform(E2)), - make_tuple( - Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( - in_e_n_ho_wo_e2_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(E0, E1)), - make_pass_through_transform(N), - make_pass_through_transform(Hop), - make_pass_through_transform(Wop), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{})); - - // output tensor - const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_pad_transform(Ho, I0, OutRightPadH), - make_pad_transform(Wo, I0, OutRightPadW)), - make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // add tensor - const auto d_k_n_hopx2_wopx2_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1)), - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_pad_transform(Hox2, I0, OutRightPadHx), - make_pad_transform(Wox2, I0, OutRightPadWx)), - make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; - - if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && - (E1 % E1PerBlock) == 0)) - { - throw std::runtime_error("wrong! 
GEMM size no divisible"); - } - - // clang-format off - - // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor - constexpr auto a_e0_e1_k_e2_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; - - // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = - make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}) - ); - - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; - - // hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor - constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 
0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - // clang-format on - - // GEMM - using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum::Set, - decltype(a_e0_e1_k_e2_grid_desc), - decltype(b_e0_e1_n_ho_wo_e2_grid_desc), - decltype(c_k_n_hop_wop_grid_desc), - decltype(d_k_n_hopx2_wopx2_grid_desc), - E1, - E2, - K2, - KPerBlock, - HoPerBlock, - WoPerBlock, - E1PerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, - Sequence<2, 3, 0, 1, 4>, - Sequence<0, 1, 2, 3, 4>, - 4, - ABlockTransferSrcScalarPerVector_E2, - ABlockTransferDstScalarPerVector_E2, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2 - 9, - BThreadTransferSrcScalarPerVector_E2, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2 - 1, - CThreadTransferDstScalarPerVector_K, - decltype(a_e0_e1_k_e2_global_step_hacks), - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks), - decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks), - decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks), - decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>; - - const auto a_e0_e1_k0_k1_e2_grid_desc = - GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc); - const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = - GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc); - const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = - GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc); - const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = - GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd( - d_k_n_hopx2_wopx2_grid_desc); - - using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); - using 
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc); - using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 = - decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc); - - const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; - - const bool has_main_e0_block_loop = E0 > 1; - - std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl; - - const auto cblockid_to_k_n_h_w_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc); - - using CBlockIdToBlockClusterAdaptor_K_N_H_W = - decltype(cblockid_to_k_n_h_w_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_e0_block_loop) - { - const auto kernel = kernel_gemm_dlops_v3_resize_add< - GridwiseGemm, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - activ_type>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_bias_grid, - p_d_grid, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor); - } - else - { - const auto kernel = kernel_gemm_dlops_v3_resize_add< - GridwiseGemm, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - activ_type>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_bias_grid, - p_d_grid, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor); - } - - return ave_time; - } -}; -#endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp deleted file mode 100644 index 34296405d4973eebe4cde767831ebc808c389409..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp +++ /dev/null @@ -1,386 +0,0 @@ -#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP -#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v3.hpp" - -template -struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad -{ - template - __host__ float Run(const ck::TensorDescriptor& wei_k_c0_y_x_c1_global_desc, - const ck::TensorDescriptor& in_n_c0_hi_wi_c1_global_desc, - const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_c_grid, - const 
int nrepeat) const - { - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0); - const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1); - const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2); - const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3); - // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4); - - const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); - const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); - const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); - const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); - - const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0); - const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2); - const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - -#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR - const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{}; - const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{}; -#else - const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; - const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; -#endif - - const auto OutRightPadH = Hop - Ho; - const auto OutRightPadW = Wop - Wo; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; - const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; - - const auto E = C0 * Y * X; - - constexpr auto E1 = Number{}; - constexpr auto E2 = Number{}; - constexpr auto K2 = Number{}; - - const auto E0 = E / E1; - - // weight tensor - const auto a_e_k_e2_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)), - make_tuple(make_pass_through_transform(K), - make_pass_through_transform(C0 * Y * X), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{})); - - const auto a_e0_e1_k_e2_grid_desc = - transform_tensor_descriptor(a_e_k_e2_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(E0, E1)), - make_pass_through_transform(K), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); - - // input tensor - const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)), - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C0), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor( - in_n_c0_hip_wip_e2_global_desc, - make_tuple( - make_pass_through_transform(N), - make_pass_through_transform(C0), - make_embed_transform(make_tuple(Y, Hop), 
make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{})); - - const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( - in_n_c0_y_ho_x_wo_e2_global_desc, - make_tuple(make_merge_transform(make_tuple(C0, Y, X)), - make_pass_through_transform(N), - make_pass_through_transform(Hop), - make_pass_through_transform(Wop), - make_pass_through_transform(E2)), - make_tuple( - Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( - in_e_n_ho_wo_e2_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(E0, E1)), - make_pass_through_transform(N), - make_pass_through_transform(Hop), - make_pass_through_transform(Wop), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{})); - - // output tensor - const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_pad_transform(Ho, I0, OutRightPadH), - make_pad_transform(Wo, I0, OutRightPadW)), - make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; - - if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && - (E1 % E1PerBlock) == 0)) - { - throw std::runtime_error("wrong! 
GEMM size no divisible"); - } - - // clang-format off - - // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor - constexpr auto a_e0_e1_k_e2_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; - - // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = - make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}) - ); - - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; - - // hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor - constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 
0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - // clang-format on - - // GEMM - using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum::Set, - decltype(a_e0_e1_k_e2_grid_desc), - decltype(b_e0_e1_n_ho_wo_e2_grid_desc), - decltype(c_k_n_hop_wop_grid_desc), - decltype(c_k_n_hop_wop_grid_desc), - E1, - E2, - K2, - KPerBlock, - HoPerBlock, - WoPerBlock, - E1PerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, - Sequence<2, 3, 0, 1, 4>, - Sequence<0, 1, 2, 3, 4>, - 4, - ABlockTransferSrcScalarPerVector_E2, - ABlockTransferDstScalarPerVector_E2, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2 - 9, - BThreadTransferSrcScalarPerVector_E2, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2 - 1, - CThreadTransferDstScalarPerVector_K, - decltype(a_e0_e1_k_e2_global_step_hacks), - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks), - decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks), - decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks), - decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>; - - const auto a_e0_e1_k0_k1_e2_grid_desc = - GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc); - const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = - GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc); - const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = - GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc); - - using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); - using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc); - using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; - - const bool has_main_e0_block_loop = E0 > 1; - - std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl; - - const auto cblockid_to_k_n_h_w_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc); - - using CBlockIdToBlockClusterAdaptor_K_N_H_W = - decltype(cblockid_to_k_n_h_w_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_e0_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - activ_type>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_bias_grid, - p_c_grid, - a_e0_e1_k0_k1_e2_grid_desc, - 
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_gemm_dlops_v3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - activ_type>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_bias_grid, - p_c_grid, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor); - } - - return ave_time; - } -}; -#endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp deleted file mode 100644 index 1b8e48e6c1e9879784520e8a3f2c97f4a5e09fda..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp +++ /dev/null @@ -1,440 +0,0 @@ -#ifndef DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP -#define DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v3.hpp" - -template -struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool -{ - template - __host__ float Run(const ck::TensorDescriptor& wei_k_c0_y_x_c1_global_desc, - const ck::TensorDescriptor& in_n_c0_hi_wi_c1_global_desc, - const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, - const ck::TensorDescriptor& max_n_k0_hx_wx_k1_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_c_grid, - FloatC* __restrict__ p_d_grid, - const int nrepeat) const - { - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0); - const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1); - const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2); - const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3); - // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4); - - const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); - const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); - const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); - const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); - - const auto Hx = max_n_k0_hx_wx_k1_global_desc.GetLength(I2); - const auto Wx = max_n_k0_hx_wx_k1_global_desc.GetLength(I3); - - const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0); - const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2); - const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = 
conv_dilations[I1]; - -#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR - const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{}; - const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{}; - - const auto OutRightPadH = Hop - Ho; - const auto OutRightPadW = Wop - Wo; - - const auto OutRightPadHx = Number{}; - const auto OutRightPadWx = Number{}; -#else - const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; - const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; - - const auto OutRightPadH = Hop - Ho; - const auto OutRightPadW = Wop - Wo; - - const auto OutRightPadHx = OutRightPadH / 2; - const auto OutRightPadWx = OutRightPadW / 2; -#endif - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; - const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; - - const auto E = C0 * Y * X; - - constexpr auto E1 = Number{}; - constexpr auto E2 = Number{}; - constexpr auto K2 = Number{}; - - const auto E0 = E / E1; - - // weight tensor - const auto a_e_k_e2_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)), - make_tuple(make_pass_through_transform(K), - make_pass_through_transform(C0 * Y * X), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{})); - - const auto a_e0_e1_k_e2_grid_desc = - transform_tensor_descriptor(a_e_k_e2_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(E0, E1)), - make_pass_through_transform(K), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); - - // input tensor - const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)), - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C0), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor( - in_n_c0_hip_wip_e2_global_desc, - make_tuple( - make_pass_through_transform(N), - make_pass_through_transform(C0), - make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{})); - - const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( - in_n_c0_y_ho_x_wo_e2_global_desc, - make_tuple(make_merge_transform(make_tuple(C0, Y, X)), - make_pass_through_transform(N), - make_pass_through_transform(Hop), - make_pass_through_transform(Wop), - make_pass_through_transform(E2)), - make_tuple( - Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto b_e0_e1_n_ho_wo_e2_grid_desc = 
transform_tensor_descriptor( - in_e_n_ho_wo_e2_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(E0, E1)), - make_pass_through_transform(N), - make_pass_through_transform(Hop), - make_pass_through_transform(Wop), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{})); - - // output tensor - const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_pad_transform(Ho, I0, OutRightPadH), - make_pad_transform(Wo, I0, OutRightPadW)), - make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // max tensor - const auto d_k_n_hx_wx_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1)), - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_pad_transform(Hx, I0, OutRightPadHx), - make_pad_transform(Wx, I0, OutRightPadWx)), - make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; - - if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && - (E1 % E1PerBlock) == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } - - // clang-format off - - // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor - constexpr auto a_e0_e1_k_e2_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; - - // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = - make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}) - ); - - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; - - constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - // clang-format on - - // GEMM - using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum::Set, - decltype(a_e0_e1_k_e2_grid_desc), - decltype(b_e0_e1_n_ho_wo_e2_grid_desc), - decltype(c_k_n_hop_wop_grid_desc), - decltype(d_k_n_hx_wx_grid_desc), - E1, - E2, - K2, - KPerBlock, - HoPerBlock, - WoPerBlock, - E1PerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, - Sequence<2, 3, 0, 1, 4>, - Sequence<0, 1, 2, 3, 4>, - 4, - ABlockTransferSrcScalarPerVector_E2, - 
ABlockTransferDstScalarPerVector_E2, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2 - 9, - BThreadTransferSrcScalarPerVector_E2, - false, // don't move back src coordinate after threadwise copy, which will be fused - // with MoveSrcSliceWindow() to save addr computation - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2 - 1, - CThreadTransferDstScalarPerVector_K, - decltype(a_e0_e1_k_e2_global_step_hacks), - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks), - decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks), - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks), - decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>; - - const auto a_e0_e1_k0_k1_e2_grid_desc = - GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc); - const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = - GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc); - const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = - GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc); - const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = - GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc); - - using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); - using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc); - using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); - - const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; - - const bool has_main_e0_block_loop = E0 > 1; - - std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl; - - const auto cblockid_to_k_n_h_w_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc); - - using CBlockIdToBlockClusterAdaptor_K_N_H_W = - decltype(cblockid_to_k_n_h_w_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_e0_block_loop) - { - const auto kernel = kernel_gemm_dlops_v3_maxpool< - GridwiseGemm, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - activ_type>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_bias_grid, - p_c_grid, - p_d_grid, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor); - } - else - { - const auto kernel = kernel_gemm_dlops_v3_maxpool< - GridwiseGemm, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - activ_type>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_bias_grid, - p_c_grid, - p_d_grid, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor); - } - - return ave_time; - } -}; -#endif diff --git 
a/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r2.hpp b/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r2.hpp deleted file mode 100644 index ce0530b3fd28e0928b0c0e36143ab65f26bdf81f..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r2.hpp +++ /dev/null @@ -1,278 +0,0 @@ -#ifndef DRIVER_GEMM_DLOPS_V1R2 -#define DRIVER_GEMM_DLOPS_V1R2 - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v1r2.hpp" - -template -__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AKMGridDesc& a_k_m_grid_desc, - const BKNGridDesc& b_k_n_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2; - - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_k_n_grid_desc.GetLength(I1); - const auto K = a_k_m_grid_desc.GetLength(I0); - - if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting"); - } - - const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - - using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); - using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); - - // c_m0_m10_m11_n0_n10_n11_grid_desc - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); - - // cblockid_to_m0_n0_block_cluster_adaptor - const auto cblockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); - - const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); - - { - std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " - << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " - << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; - } - - float ave_time = 0; - - 
if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor); - } - - return ave_time; -} -#endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r3.hpp b/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r3.hpp deleted file mode 100644 index 3fd1a1dbbac9735b4228719ce3c5d05c22fd5219..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r3.hpp +++ /dev/null @@ -1,275 +0,0 @@ -#ifndef DRIVER_GEMM_DLOPS_V1R3 -#define DRIVER_GEMM_DLOPS_V1R3 - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v1r3.hpp" - -template -__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseGemm = - GridwiseGemmDlops_km_kn_mn_v1r3; - - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error("wrong! 
GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting"); - } - - const auto a_k0_m0_m1_k1_grid_desc = - GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); - const auto b_k0_n0_n1_k1_grid_desc = - GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); - - using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); - using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); - - // c_m0_m10_m11_n0_n10_n11_grid_desc - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); - - // cblockid_to_m0_n0_block_cluster_adaptor - const auto cblockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); - - const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); - - { - std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; - } - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - 
cblockid_to_m0_n0_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor); - } - - return ave_time; -} -#endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp b/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp deleted file mode 100644 index 5652040250e14febb82b22a8ccf5daa9de4ab21f..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp +++ /dev/null @@ -1,220 +0,0 @@ -#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP -#define DRIVER_GEMM_XDLOPS_V2R3_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" -#include "element_wise_operation.hpp" - -template -__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K& b_grid_desc_k0_n_k1, - const CMNGridDesc& c_grid_desc_m_n, - ck::index_t M01, - ck::index_t N01, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - { - std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", " - << a_grid_desc_k0_m_k1.GetLength(I1) << ", " << a_grid_desc_k0_m_k1.GetLength(I2) - << "}" << std::endl; - - std::cout << "b_grid_desc_k0_n_k1{" << b_grid_desc_k0_n_k1.GetLength(I0) << ", " - << b_grid_desc_k0_n_k1.GetLength(I1) << ", " << b_grid_desc_k0_n_k1.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_grid_desc_m_n{ " << c_grid_desc_m_n.GetLength(I0) << ", " - << c_grid_desc_m_n.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, M01, N01)) - { - throw std::runtime_error( - "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); - } - - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - - const auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01); - - using Block2CTileMap = decltype(block_2_ctile_map); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n); - - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - - auto element_op_ = ElementwiseOperation{}; - - if(has_main_k0_block_loop) - { - const auto kernel = - kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - ElementwiseOperation, - ElementwiseOperation, - ElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - element_op_, - element_op_, - element_op_, - block_2_ctile_map); - } - else - { - const auto kernel = - kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - ElementwiseOperation, - ElementwiseOperation, - ElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - element_op_, - element_op_, - element_op_, - block_2_ctile_map); - } - return ave_time; -} -#endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp b/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp deleted file mode 100644 index 6e9983b0b50bde7a383842cf5bf87bd7d611472f..0000000000000000000000000000000000000000 --- a/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp +++ /dev/null @@ -1,213 +0,0 @@ -#ifndef DRIVER_GEMM_XDLOPS_V2R4 -#define DRIVER_GEMM_XDLOPS_V2R4 - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r4.hpp" - -template -__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, - const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - ck::index_t M01, - ck::index_t N01, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - using GridwiseGemm = - GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4; - - { - std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", " - << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", " - << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", " - << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", " - << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", " - << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", " - << 
b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " - << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity( - a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting"); - } - - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - - using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - - const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); - - const auto c_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch); - - using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch); - { - std::cout << "gridSize : " << grid_size << std::endl; - } - - const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r4, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r4, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); - } - - return ave_time; -} -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index f4944a28d2ec55e13bd42cbc86ff28a5427f7523..06e74a9e9aa92c4132c7c7ed28f2f7b0288d77fb 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -1,10 +1,13 @@ -#ifndef REFERENCE_BATCHED_GEMM_HPP -#define REFERENCE_BATCHED_GEMM_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
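// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch above: the obsolete offline
// drivers deleted in the hunks above (driver_gemm_dlops_v1r2/v1r3,
// driver_gemm_xdlops_v2r3/v2r4) all share one structure -- runtime loop
// predicates select a kernel template instantiation, which is then timed with
// launch_and_time_kernel. The host-only analogue below mirrors that
// compile-time/run-time split; run_kernel and the predicate formulas are
// hypothetical stand-ins, not the CK implementations.
// ---------------------------------------------------------------------------
#include <cstdio>

template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void run_kernel(int k0) // stand-in for e.g. kernel_gemm_dlops_v1r3<..., HasMainLoop, HasDoubleTail>
{
    std::printf("K0 = %d, main loop = %d, double tail = %d\n",
                k0,
                int(HasMainKBlockLoop),
                int(HasDoubleTailKBlockLoop));
}

void dispatch_by_k0(int k0, int k0_per_block)
{
    // Assumed predicates, for illustration only; the real ones live in
    // GridwiseGemm::CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop.
    const bool has_main        = (k0 / k0_per_block) > 2;
    const bool has_double_tail = (k0 / k0_per_block) % 2 == 0;

    if(has_main && has_double_tail)
        run_kernel<true, true>(k0);
    else if(has_main)
        run_kernel<true, false>(k0);
    else if(has_double_tail)
        run_kernel<false, true>(k0);
    else
        run_kernel<false, false>(k0);
}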
+ +#pragma once #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { @@ -59,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator for(int k = 0; k < K; ++k) { - float v_a; - float v_b; + ADataType v_a; + BDataType v_b; - arg.a_element_op_(v_a, static_cast(arg.a_g_m_k_(g, m, k))); - arg.b_element_op_(v_b, static_cast(arg.b_g_k_n_(g, k, n))); + arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); + arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n)); - v_acc += v_a * v_b; + v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); } float v_c; arg.c_element_op_(v_c, v_acc); - arg.c_g_m_n_(g, m, n) = v_c; + arg.c_g_m_n_(g, m, n) = ck::type_convert(v_c); }; make_ParallelTensorFunctor(f_gmk_gkn_gmn, @@ -132,4 +135,3 @@ struct ReferenceBatchedGemm : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp index c6a530476645d7313bd24ab863fa16d25f2e4272..cde072578991bd5441a8cb2803b5769e7f05ce45 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp @@ -1,33 +1,13 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2022 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
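// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch above: the reference_batched_gemm
// hunk changes the inner loop so the element-wise operators see the native
// A/B types, the product is accumulated through ck::type_convert into the
// accumulation type, and the result is converted back to the C type on store.
// The standalone host function below mirrors that pattern with plain types;
// static_cast stands in for ck::type_convert and the element-wise ops are
// identity (PassThrough) here.
// ---------------------------------------------------------------------------
#include <cstddef>
#include <vector>

std::vector<float> reference_gemm_rowmajor(const std::vector<float>& a_m_k,
                                           const std::vector<float>& b_k_n,
                                           int M, int N, int K)
{
    std::vector<float> c_m_n(static_cast<std::size_t>(M) * N, 0.f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            double v_acc = 0.0; // AccDataType
            for(int k = 0; k < K; ++k)
            {
                const float v_a = a_m_k[static_cast<std::size_t>(m) * K + k]; // ADataType, after a_element_op_
                const float v_b = b_k_n[static_cast<std::size_t>(k) * N + n]; // BDataType, after b_element_op_
                v_acc += static_cast<double>(v_a) * static_cast<double>(v_b); // ck::type_convert<AccDataType> analogue
            }
            // c_element_op_ would be applied to v_acc here before the final conversion.
            c_m_n[static_cast<std::size_t>(m) * N + n] = static_cast<float>(v_acc); // ck::type_convert<CDataType> analogue
        }
    return c_m_n;
}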
+ #pragma once + #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp index 4203085dbc67724e25db8e2e70893246f1a03944..6cab5f28f47766566bed3e9dc5028a6d38641485 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp @@ -1,9 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index 45fc8b8503466f89c0fc2abbe71497f201bec6fa..1239ca163afd085c1197f7859991dd61b5274824 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -1,10 +1,14 @@ -#ifndef REFERENCE_CONV_BWD_DATA_HPP -#define REFERENCE_CONV_BWD_DATA_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { @@ -106,9 +110,8 @@ struct ReferenceConvBwdData : public device::BaseOperator } } - float v_in; - arg.in_element_op_(v_in, v_acc); - arg.input_(n, c, wi) = ck::type_convert(v_in); + arg.in_element_op_(v_acc, v_acc); + arg.input_(n, c, wi) = ck::type_convert(v_acc); }; make_ParallelTensorFunctor(f_ncw, @@ -352,4 +355,3 @@ struct ReferenceConvBwdData : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index d1afa898e406fa57799a99db4d6c7fb36aea6636..fc333fbd6a089d34cf57b0f666098d0d9e4ad36a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -1,12 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
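// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch above: the reference-operator
// headers in these hunks switch from bare includes to paths rooted at the
// "ck/" tree and from include guards to #pragma once. A downstream translation
// unit would now spell its includes as below, assuming the composable_kernel
// include directory (or install prefix) is on the compiler's include path; the
// bare "device_base.hpp"/"host_tensor.hpp" spellings removed above no longer
// resolve.
// ---------------------------------------------------------------------------
#include "ck/tensor_operation/gpu/device/device_base.hpp"                       // was: "device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"                               // was: "host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"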
+ #pragma once #include #include #include -#include "stream_config.hpp" -#include "device_base.hpp" -#include "host_tensor.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp index 4be6169c1504a2c5b2293b09d30ada3ad123d1d4..9309ef6e8f6327d248371b1ce4d41d2a4641af92 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp @@ -1,10 +1,13 @@ -#ifndef REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP -#define REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { @@ -187,4 +190,3 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp index 466537c686aa7c8b9656ccdaf7756f50559589fd..44fa3520240c55e7ed2d9e7099fed7bddd903dad 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp @@ -1,10 +1,13 @@ -#ifndef REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP -#define REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { @@ -195,4 +198,3 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 6f097c6debb4a290add59ee9e789e3f309a5adfc..e3dd4de5dfd32ef530b9e810182ca55edf2ecd04 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -1,8 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once + #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { @@ -58,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator for(int k = 0; k < K; ++k) { - AccDataType v_a; - AccDataType v_b; + ADataType v_a; + BDataType v_b; - arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); - arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); - v_acc += v_a * v_b; + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); } AccDataType v_c; arg.c_element_op_(v_c, v_acc); - arg.c_m_n_(m, n) = v_c; + arg.c_m_n_(m, n) = ck::type_convert(v_c); }; make_ParallelTensorFunctor( diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp index 3e7f220e03d752102224f449a105ebcdb4ce0b85..cd3383b9945c4af15ff51ce8d499dfad1a40e560 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp @@ -1,10 +1,13 @@ -#ifndef REFERENCE_GEMM_BIAS_BIAS_2D_HPP -#define REFERENCE_GEMM_BIAS_BIAS_2D_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { @@ -66,8 +69,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator for(int k = 0; k < K; ++k) { - arg.a_element_op_(a, arg.a_m_k_(m, k)); - arg.b_element_op_(b, arg.b_k_n_(k, n)); + arg.a_element_op_(a, ck::type_convert(arg.a_m_k_(m, k))); + arg.b_element_op_(b, ck::type_convert(arg.b_k_n_(k, n))); acc += a * b; } @@ -131,4 +134,3 @@ struct ReferenceGemmBias2D : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp index 60f72e9e510eb64f2131d6e13d2abfcc49a273fc..33d7cbb83721e7defa4ce2008934e6344d62e3df 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp @@ -1,10 +1,14 @@ -#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_HPP -#define REFERENCE_GEMM_BIAS_ACTIVATION_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { @@ -134,4 +138,3 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp index 5e0ec75e5e8b35d9b267d997a384b81e5c369104..1ae63d2f86a88895e4135221fd9813060835c0b8 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp @@ -1,10 +1,14 @@ -#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP -#define REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once #include #include -#include "device_base.hpp" -#include "host_tensor.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace tensor_operation { @@ -142,4 +146,3 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b1e72459fd82a0f2f1c0b85f0be5cb952888e57b --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +struct ReferenceGemmLayernorm : public device::BaseOperator +{ + using ReferenceGemmInstance = ReferenceGemm; + + template + static void RunLayernorm(Tensor& result, + const Tensor& acc, // MxN + const Tensor& gamma, // 1xN + const Tensor& beta, // 1xN + const InDataType epsilon = 1e-5) + { + assert(acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] && + acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]); + + size_t M = acc.mDesc.GetLengths()[0]; + size_t N = acc.mDesc.GetLengths()[1]; + + Tensor avg_acc_sq(HostTensorDescriptor(std::vector({M}))); + Tensor avg_acc(HostTensorDescriptor(std::vector({M}))); + Tensor acc_layernorm(acc); + + // reduce N dim + for(size_t i = 0; i < M; i++) + { + ComputeDataType sum_acc_sq = 0; + ComputeDataType sum_acc = 0; + for(size_t j = 0; j < N; j++) + { + sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j); + sum_acc += acc_layernorm(i, j); + } + avg_acc_sq(i) = sum_acc_sq / N; + avg_acc(i) = sum_acc / N; + } + + // normalize + acc_layernorm.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) = + (self(idx[0], idx[1]) - avg_acc(idx[0])) / + sqrt(avg_acc_sq(idx[0]) - avg_acc(idx[0]) * avg_acc(idx[0]) + epsilon); + }); + + // affine + acc_layernorm.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) = self(idx[0], idx[1]) * gamma(idx[1]) + beta(idx[1]); + }); + + // cast + result = acc_layernorm.template CopyAsType(); + } + + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n_bias, // 1xN + const Tensor& c0_m_n_add, // MxN + const Tensor& c0_n_gamma, // 1xN + const Tensor& c0_n_beta, // 1xN + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + const CDataType epsilon = 1e-5) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_bias_{c0_n_bias}, + c0_m_n_add_{c0_m_n_add}, + c0_n_gamma_{c0_n_gamma}, + c0_n_beta_{c0_n_beta}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + c_element_op_{c_element_op}, + epsilon_{epsilon} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_bias_; + const Tensor& c0_m_n_add_; + const Tensor& c0_n_gamma_; + const Tensor& c0_n_beta_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + CElementwiseOperation c_element_op_; + + const CDataType epsilon_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + // using Argument = ReferenceGemm::Argument; + + float Run(const Argument& arg) + { + Tensor acc_m_n(arg.c_m_n_.mDesc); + acc_m_n.GenerateTensorValue(GeneratorTensor_1{0}); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(arg.a_m_k_, + arg.b_k_n_, + acc_m_n, + arg.a_element_op_, + arg.b_element_op_, + element_wise::PassThrough{}); + + // gemm + ref_invoker.Run(ref_argument); + + // activation(acc + bias) + acc_m_n.ForEach([&](auto& self, auto idx) { + AccDataType out; + arg.acc_element_op_(out, acc_m_n(idx[0], idx[1]) + 
arg.c0_n_bias_(idx[1])); + self(idx[0], idx[1]) = out; + }); + + // add from other layers + acc_m_n.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) += arg.c0_m_n_add_(idx[0], idx[1]); + }); + + // layernorm + RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_gamma_, arg.c0_n_beta_); + + // elementwise op + arg.c_m_n_.ForEach([&](auto& self, auto idx) { + arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1])); + }); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n_bias, // 1xN + const Tensor& c0_m_n_add, // 1xN + const Tensor& c0_n_gamma, // 1xN + const Tensor& c0_n_beta, // 1xN + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + const CDataType epsilon = 1e-5) + { + return Argument{a_m_k, + b_k_n, + c_m_n, + c0_n_bias, + c0_m_n_add, + c0_n_gamma, + c0_n_beta, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + epsilon}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmLayernorm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5d9e90f71abe96c069a6ef0ce1fc08bba02d955f --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
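// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch above: the new
// ReferenceGemmLayernorm host operator computes
//   D = Layernorm(acc_element_op(A * B + bias) + add) * gamma + beta
// row by row. The standalone function below reproduces just the RunLayernorm
// step (mean and mean-of-squares reduction over N, normalization with epsilon,
// then the affine gamma/beta transform), using std::vector instead of ck
// Tensor.
// ---------------------------------------------------------------------------
#include <cmath>
#include <cstddef>
#include <vector>

// Row-wise layernorm: var(x) = E[x^2] - E[x]^2, then
// (x - E[x]) / sqrt(var + eps) * gamma + beta, as in RunLayernorm above.
void layernorm_rows(std::vector<float>& acc,         // M x N, updated in place
                    const std::vector<float>& gamma, // N
                    const std::vector<float>& beta,  // N
                    std::size_t M, std::size_t N, float eps = 1e-5f)
{
    for(std::size_t i = 0; i < M; ++i)
    {
        float sum = 0.f, sum_sq = 0.f;
        for(std::size_t j = 0; j < N; ++j)
        {
            sum    += acc[i * N + j];
            sum_sq += acc[i * N + j] * acc[i * N + j];
        }
        const float avg     = sum / N;
        const float avg_sq  = sum_sq / N;
        const float inv_std = 1.f / std::sqrt(avg_sq - avg * avg + eps);
        for(std::size_t j = 0; j < N; ++j)
            acc[i * N + j] = (acc[i * N + j] - avg) * inv_std * gamma[j] + beta[j];
    }
}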
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceSoftmax : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in, + Tensor& out, + AccDataType alpha, + AccDataType beta, + const std::vector sm_reduce_dims) + : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims) + { + // std::cout << "debug: scalar dims: "; + for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++) + { + if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) == + sm_reduce_dims.end()) + { + sm_scalar_dims_.push_back(i); + // std::cout << i << ", "; + } + } + // std::cout << std::endl; + } + + const Tensor& in_; + Tensor& out_; + AccDataType alpha_; + AccDataType beta_; + std::vector sm_reduce_dims_; + std::vector sm_scalar_dims_; // dim after internal max/sum reduction + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + std::vector scalar_lengths; + for(index_t dim : arg.sm_scalar_dims_) + { + scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]); + } + + Tensor reduce_max(scalar_lengths); + reduce_max.GenerateTensorValue( + GeneratorTensor_1{std::numeric_limits::lowest()}); + Tensor reduce_sum(scalar_lengths); + reduce_sum.GenerateTensorValue(GeneratorTensor_1{0}); + + auto to_sm_scalar_idx = [&](auto idx) { + std::vector sm_scalar_idx; + for(index_t dim : arg.sm_scalar_dims_) + { + sm_scalar_idx.push_back(idx[dim]); + } + return sm_scalar_idx; + }; + + arg.in_.ForEach([&](auto& self, auto idx) { + reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)), + static_cast(self(idx))); + }); + + // LogRangeAsType(std::cout << "reduce_max: ", reduce_max.mData, ",") << + // std::endl; + + Tensor in_stable(arg.in_.mDesc); + in_stable.ForEach([&](auto& self, auto idx) { + // numerator = exp(x - max(x)) + self(idx) = std::exp(static_cast(arg.in_(idx)) - + reduce_max(to_sm_scalar_idx(idx))); + }); + + // LogRangeAsType(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl; + + in_stable.ForEach([&](auto& self, auto idx) { + // denominator = sum(exp(x - max(x))) + reduce_sum(to_sm_scalar_idx(idx)) += self(idx); + }); + + // LogRangeAsType(std::cout << "reduce_sum: ", reduce_sum.mData, ",") << + // std::endl; + + arg.out_.ForEach([&](auto& self, auto idx) { + self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) + + arg.beta_ * self(idx); + }); + + // LogRangeAsType(std::cout << "out: ", arg.out_.mData, ",") << std::endl; + // reduction along reduce dims + // LogRangeAsType(std::cout << "reduce_max: ", reduce_max.mData, ",") << + // std::endl; LogRangeAsType(std::cout << "reduce_sum: ", reduce_sum.mData, ",") + // << std::endl; + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in, + Tensor& out, + AccDataType alpha, + AccDataType beta, + const std::vector sm_reduce_dims) + 
{ + return Argument{in, out, alpha, beta, sm_reduce_dims}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceSoftmax" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp index 120938f0722223516bc8d775efd377323a054575..df4fca6562755012200ff888677bcac1c90129a2 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef NAIVE_CONV_FWD_HPP #define NAIVE_CONV_FWD_HPP diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp similarity index 55% rename from library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp rename to library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp index 40fd7274ef94cc67e3442188ef1559ea99e52cc5..20df1b3616a016029f4083bd821b4e286b7b727c 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp @@ -1,14 +1,20 @@ -#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP -#define CK_DEVICE_OPERATION_INSTANCE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include +#pragma once + +#include +#include + +#include "ck/utility/functional2.hpp" namespace ck { namespace tensor_operation { namespace device { +namespace instance { -template -void add_device_operation_instances(std::vector>& op_instances, +template +void add_device_operation_instances(std::vector>& op_instances, const NewOpInstances& new_op_instances) { ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { @@ -16,11 +22,14 @@ void add_device_operation_instances(std::vector>& op using NewOpInstance = remove_cvref_t; + static_assert(std::is_base_of_v, + "wrong! NewOpInstance should be derived from BaseOp"); + op_instances.push_back(std::make_unique(new_op_instance)); }); } +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp new file mode 100644 index 0000000000000000000000000000000000000000..66230ac45c365db3444154908d32989555fb62aa --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
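// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch above: the new ReferenceSoftmax
// host operator computes out = alpha * softmax(in) + beta * out over the
// requested reduce dimensions, subtracting the per-slice maximum for
// numerical stability. The standalone function below shows the common case of
// a 2-D tensor reduced along its last dimension (each row is one softmax).
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void softmax_rows(const std::vector<float>& in, // M x N, row-major
                  std::vector<float>& out,      // M x N, read for the beta term
                  std::size_t M, std::size_t N,
                  float alpha = 1.f, float beta = 0.f)
{
    std::vector<float> stable(N);
    for(std::size_t i = 0; i < M; ++i)
    {
        const float row_max =
            *std::max_element(in.begin() + i * N, in.begin() + (i + 1) * N);
        float row_sum = 0.f;
        for(std::size_t j = 0; j < N; ++j)
        {
            stable[j] = std::exp(in[i * N + j] - row_max); // numerator = exp(x - max(x))
            row_sum += stable[j];                          // denominator = sum of numerators
        }
        for(std::size_t j = 0; j < N; ++j)
            out[i * N + j] = alpha * stable[j] / row_sum + beta * out[i * N + j];
    }
}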
+ +#pragma once + +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// aliasing, for commonly used type +using F64 = double; +using F32 = float; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; + +using EMPTY_TUPLE = ck::Tuple<>; + +using F16_TUPLE = ck::Tuple; +using F16_F16_TUPLE = ck::Tuple; + +using F32_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +template +struct DeviceOperationInstanceFactory; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0655fd92e44c5c6400c5a66e6ee9e21f6e50da71 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory> +{ + using 
DeviceOp = DeviceBatchedGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9bb8e5ce525bdc387a501ec20ee65faec1a1fcbd --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
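// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch above: the new
// DeviceOperationInstanceFactory specialization returns every registered
// DeviceBatchedGemm instance for one combination of layouts, data types, and
// element-wise operators; an unmatched combination yields an empty vector.
// The template-parameter order of DeviceBatchedGemm (layouts, then data
// types, then element-wise ops) is elided in the hunk above and assumed here,
// and GetTypeString() is assumed to be available through device::BaseOperator
// as declared on the reference operators in this diff. Building and running
// arguments is instance-specific and not shown.
// ---------------------------------------------------------------------------
#include <iostream>

#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

void list_batched_gemm_fp16_instances()
{
    using namespace ck::tensor_operation::device::instance;

    using Factory = DeviceOperationInstanceFactory<
        ck::tensor_operation::device::DeviceBatchedGemm<Row, Row, Row,
                                                        F16, F16, F16,
                                                        PassThrough, PassThrough, PassThrough>>;

    const auto op_ptrs = Factory::GetInstances();
    std::cout << "found " << op_ptrs.size() << " batched GEMM instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << "\n"; // assumed BaseOperator::GetTypeString()
}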
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + std::vector>>& instances); + +// Contraction + Bilinear +template +struct DeviceOperationInstanceFactory, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>> +{ + using DeviceOp = DeviceContractionMultipleD, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6eb5b1d0cc40d90e5769edaacc6104ba47462a2b --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
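// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch above: the contraction_bilinear
// factory is keyed strictly by its template parameters -- only the all-f32,
// NumDimM = NumDimN = NumDimK = 2 case with a single F32 D tensor and the
// Bilinear epilogue is registered, so any other combination returns an empty
// instance list. The DeviceContractionMultipleD parameter order (rank counts,
// data types, D tuple, E type, element-wise ops) is assumed from the
// specialization above, whose arguments are elided in the hunk.
// ---------------------------------------------------------------------------
#include <iostream>

#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"

void list_contraction_bilinear_instances()
{
    using namespace ck::tensor_operation::device::instance;

    using Factory = DeviceOperationInstanceFactory<
        ck::tensor_operation::device::DeviceContractionMultipleD<
            2, 2, 2, F32, F32, ck::Tuple<F32>, F32,
            PassThrough, PassThrough, Bilinear>>;

    const auto op_ptrs = Factory::GetInstances();
    std::cout << op_ptrs.size() << " contraction + bilinear instances\n";
}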
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + std::vector>>& instances); + +// Contraction + Scale +template +struct DeviceOperationInstanceFactory, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>> +{ + using DeviceOp = DeviceContractionMultipleD, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a9cc8b79dd95eba5cc2e7d9bbfbcabbd881a3c91 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Normalize = ck::tensor_operation::element_wise::Normalize; +using DeviceNormalizeFromMeanMeanSquarePtr = + ck::tensor_operation::device::DeviceElementwisePtr<5, 1, 2, Normalize>; + +void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( + std::vector& instances); + +template +auto get_device_normalize_from_mean_meansquare_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value && is_same::value && + is_same::value && is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs); + } + + return op_ptrs; +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..682f54675982e9338823a814c46075242534de1d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +template +auto get_device_gemm_add_add_mean_squaremean_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + 
ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + op_ptrs); + } + } + + return op_ptrs; +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55ca8f42941d412ab25e03c28033c710ef8dc3b6 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -0,0 +1,383 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + + instances); + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>>& + instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemm> +{ + using DeviceOp = DeviceGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + 
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e2cd64b34ee098c7cbffdae1486c7d80065c9bdb --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
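A sketch of the intended client-side use of the GEMM factory above, assuming the conventional DeviceGemm parameter order (A/B/C layouts, A/B/C data types, three element-wise ops) and the gemm tensor_layout tags used elsewhere in the library:

    // Sketch: enumerate every f16 GEMM instance for row-major A/C and column-major B.
    #include <iostream>
    #include "ck/ck.hpp"
    #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
    #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
    #include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

    int main()
    {
        using Row         = ck::tensor_layout::gemm::RowMajor;
        using Col         = ck::tensor_layout::gemm::ColumnMajor;
        using F16         = ck::half_t;
        using PassThrough = ck::tensor_operation::element_wise::PassThrough;

        using DeviceOp = ck::tensor_operation::device::
            DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        // For this layout the factory collects the xdl, dl, xdl_c_shuffle and
        // xdl_c_shuffle_2_stage mk_nk instance lists declared above.
        for(const auto& op : op_ptrs)
            std::cout << op->GetTypeString() << '\n';
    }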
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>&); + +// GEMM + Add + Add + FastGelu +template +struct DeviceOperationInstanceFactory, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>> +{ + using DeviceOp = DeviceGemmMultipleD, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp new file mode 100644 index 0000000000000000000000000000000000000000..37731fde06f66705cc5c777b4b2968d5aaf5e736 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
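For the Add + Add + FastGelu fusion above, the factory key is a DeviceGemmMultipleD with two D tensors; a sketch of the f16 key, with the layout/tuple ordering treated as an assumption:

    // Sketch: E = FastGelu(A*B + D0 + D1), all f16, row-major A/D0/D1/E, column-major B.
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;
    using F16 = ck::half_t;

    using GemmAddAddFastGeluOp = ck::tensor_operation::device::DeviceGemmMultipleD<
        Row, Col,
        ck::Tuple<Row, Row>, // D0/D1 layouts (assumed)
        Row,
        F16, F16,
        ck::Tuple<F16, F16>, // D0/D1 data types
        F16,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::AddAddFastGelu>;

    const auto fastgelu_op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<GemmAddAddFastGeluOp>::GetInstances();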
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +// GEMM + Bilinear +template +struct DeviceOperationInstanceFactory, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>> +{ + using DeviceOp = DeviceGemmMultipleD, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8986a793444c8f73679d892f4d55adb02bc61834 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
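The Bilinear element-wise op used by the factory above carries the alpha/beta scalars itself; a small sketch of the per-element epilogue it is assumed to compute (constructor and call signature are assumptions, not verified against element_wise_operation.hpp):

    // Sketch: e = alpha * c + beta * d, where c is the accumulated A*B value and d the extra D input.
    ck::tensor_operation::element_wise::Bilinear bilinear{2.0f, 3.0f}; // alpha = 2, beta = 3 (assumed ctor)

    float e       = 0.0f;
    const float c = 1.0f;
    const float d = 1.0f;
    bilinear(e, c, d); // e == 2 * 1 + 3 * 1 == 5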
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmSplitK> +{ + using DeviceOp = DeviceGemmSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp index 6f0dbe75fff1ba1c39541bfe0ad541480db95164..97e9addfb9f479936352b651406189e85a9d638b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp @@ -1,26 +1,26 @@ -#ifndef DEVICE_REDUCE_INSTANTCE_HPP -#define DEVICE_REDUCE_INSTANTCE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
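DeviceGemmSplitK is keyed exactly like DeviceGemm, so selecting split-K instances only changes the operation type handed to the factory; a sketch for f32 with row-major A and B, under the same parameter-order assumption as the GEMM sketch above:

    // Sketch: enumerate the f32 split-K GEMM instances for row-major A and B.
    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using DeviceSplitKOp = ck::tensor_operation::device::DeviceGemmSplitK<
        RowMajor, RowMajor, RowMajor,
        float, float, float,
        PassThrough, PassThrough, PassThrough>;

    const auto splitk_op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceSplitKOp>::GetInstances();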
-#include "device_reduce_instance_blockwise_f16_f16_f16.hpp" -#include "device_reduce_instance_blockwise_f16_f32_f16.hpp" -#include "device_reduce_instance_blockwise_f32_f32_f32.hpp" -#include "device_reduce_instance_blockwise_f32_f64_f32.hpp" -#include "device_reduce_instance_blockwise_f64_f64_f64.hpp" -#include "device_reduce_instance_blockwise_i8_i8_i8.hpp" -#include "device_reduce_instance_blockwise_i8_i32_i8.hpp" -#include "device_reduce_instance_blockwise_b16_f32_b16.hpp" -#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp" -#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp" -#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp" -#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp" -#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp" -#include "device_reduce_instance_threadwise_f16_f16_f16.hpp" -#include "device_reduce_instance_threadwise_f16_f32_f16.hpp" -#include "device_reduce_instance_threadwise_f32_f32_f32.hpp" -#include "device_reduce_instance_threadwise_f32_f64_f32.hpp" -#include "device_reduce_instance_threadwise_f64_f64_f64.hpp" -#include "device_reduce_instance_threadwise_i8_i8_i8.hpp" -#include "device_reduce_instance_threadwise_i8_i32_i8.hpp" -#include "device_reduce_instance_threadwise_b16_f32_b16.hpp" +#pragma once -#endif +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp" +#include 
"ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp index e31d4e769eda28011fa527cfe0c3df13c55b8df7..5fd8c95f8424ba56537def4f2cca5159349d8a01 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp @@ -1,14 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_impl_common.hpp" -#include "device_reduce_multiblock.hpp" +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { using reduce_configuration_1_instances_blockwise = std::tuple< // clang-format off @@ -61,10 +63,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple< >; #endif -template +template using deviceReduceBlockWisePtrType = DeviceReducePtr< - typename reduce_unary_operator::InElementwiseOperation, - typename reduce_unary_operator::AccElementwiseOperation>; + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; template void add_device_reduce_instance_blockwise( - std::vector>& device_op_instances) + std::vector>& device_op_instances) { - using ReduceOperation = typename reduce_binary_operator::opType; + using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; constexpr bool Indexable = (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || @@ -137,7 +138,7 @@ void add_device_reduce_instance_blockwise( ReduceOpId, \ PropagateNan, \ UseIndex>( \ - std::vector> & device_op_instances) + std::vector> & device_op_instances) #define ADD_BLOCKWISE_INST_BY_ID( \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ @@ -150,21 +151,17 @@ void add_device_reduce_instance_blockwise( Rank, \ NumReduceDim) -#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ - inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ - extern template void add_device_reduce_instance_blockwise( \ - std::vector::InElementwiseOperation, \ - typename reduce_unary_operator:: \ - AccElementwiseOperation>> & \ - device_op_instances) +#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_blockwise( \ + std::vector> & device_op_instances) #define 
ADD_BLOCKWISE_INST_REF_BY_ID( \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ @@ -177,10 +174,7 @@ void add_device_reduce_instance_blockwise( Rank, \ NumReduceDim) -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp index 3cad45f2e5d73dd792d5d58f840ff48205423fe1..8d1fed046a80457baad044dc8471bca40adcb4de 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp @@ -1,13 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "data_type.hpp" -#include "device_reduce_instance_blockwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -50,10 +53,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp index 441c1aec3ff25f9d26d174bfe18db354edaadbb8..ae7f13ce979bf25bc91383123bdd76ff7a5387ec 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp @@ -1,13 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
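For reading the instance tables in these reduce headers, the numeric ReduceOpId column appears to index ck::ReduceTensorOp; the mapping below is an assumption, not verified against reduction_enums.hpp:

    // Assumed ReduceOpId -> ck::ReduceTensorOp mapping (assumption):
    //   0 = ADD, 1 = MUL, 2 = MIN, 3 = MAX, 4 = AMAX, 5 = AVG, 6 = NORM1, 7 = NORM2
    // Under that reading, ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1)
    // declares extern AMAX instances: bf16 input, f32 accumulation, bf16 output,
    // NaN propagation off, index output on, tensor rank 4, one reduced dimension.
    // This is consistent with the Indexable check above, which accepts MIN, MAX and
    // (presumably) AMAX as the index-producing reductions.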
-#include "data_type.hpp" -#include "device_reduce_instance_blockwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -37,10 +40,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp index ca8532a458c8c155f84f20b12f3602deb012733f..c26e136593e4adef47b3b9b08e981922a1ede86c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp @@ -1,13 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "data_type.hpp" -#include "device_reduce_instance_blockwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -25,10 +28,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp index 64f504c9da528ee8aedc23bc88a3178846b40089..30064d588daf4c4ee562b6526c4396e25f74d8cd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "device_reduce_instance_blockwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -48,10 +52,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp index 9e84ee34fb355f713254dfaabfffde6f1f4908e0..c9f6a1a5ff8b1437491cf4e57d413a6cd9469bb5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "device_reduce_instance_blockwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -24,10 +28,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1); ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp index a37e3bdeb91af14679e9dfe3a2ea255b0f1b6438..c598e64cde7625d35653470d174b8129f4330af9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "device_reduce_instance_blockwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -48,10 +52,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp index 1d8695bbb0f158014e45d2eb49d11c668626dc14..cd1594992982ec77cd3c3460fbada555cd5cabc7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "device_reduce_instance_blockwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -20,10 +24,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp index b5c19b72072d550832113956eb54b03c5cdbe4e1..bf62f92ad89b2d8d163c4abc94602ac0ce34552e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "device_reduce_instance_blockwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -36,10 +40,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp index 721d98a7189f3f876bc60a19cb057e655f68d70f..9fc409a08e238386b848b7bf139268966a89603e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp @@ -1,10 +1,12 @@ -#ifndef DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP -#define DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { template struct ReductionConfiguration_1 @@ -32,10 +34,7 @@ struct ReductionConfiguration_2 #define QUICK_REDUCE_TEST 1 -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp index 605109d0779f8b3f0aa546a27c3a39f4acdf22c1..a74e92ecab39c90b6473f978af0e7e511ee8c298 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp @@ -1,14 +1,17 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_impl_common.hpp" -#include "device_reduce_multiblock.hpp" +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { using reduce_configuration_1_instances_multiblock_atomic_add = std::tuple< // clang-format off @@ -61,12 +64,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< >; #endif -template -using deviceReduceMultiBlockAtomicAddPtrType = - DeviceReducePtr:: - InElementwiseOperation, - typename reduce_unary_operator:: - AccElementwiseOperation>; +template +using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr< + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; template void add_device_reduce_instance_multiblock_atomic_add( - std::vector>& - device_op_instances) + std::vector>& device_op_instances) { - using ReduceOperation = typename reduce_binary_operator::opType; + using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; constexpr bool Indexable = (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || @@ -158,8 +157,7 @@ void add_device_reduce_instance_multiblock_atomic_add( ReduceOpId, \ PropagateNan, \ UseIndex>( \ - std::vector> & \ - device_op_instances) + std::vector> & device_op_instances) #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ @@ -172,21 +170,17 @@ void add_device_reduce_instance_multiblock_atomic_add( Rank, \ NumReduceDim) -#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \ - inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ - extern template void add_device_reduce_instance_multiblock_atomic_add( \ - std::vector::InElementwiseOperation, \ - typename reduce_unary_operator:: \ - AccElementwiseOperation>> & \ - device_op_instances) +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_multiblock_atomic_add( \ + std::vector> & device_op_instances) #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ @@ -199,10 +193,7 @@ void add_device_reduce_instance_multiblock_atomic_add( Rank, \ NumReduceDim) -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp index 4e39cf49f6f0b121aa8e905c06bbd713c4803354..3efc5850685af3d7a0074f12262371f32e49c5ce 100644 --- 
a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp @@ -1,13 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "data_type.hpp" -#include "device_reduce_instance_multiblock_atomic_add.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -21,10 +24,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp index 73424322ae2b095e970f1bd4e1aaa83a6af89156..804cba12cc42e1d0d446248943e80add5f789306 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp @@ -1,13 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "data_type.hpp" -#include "device_reduce_instance_multiblock_atomic_add.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -21,10 +24,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp index ecc9c4ea8712478a92ce03b56710cae59648ab07..32eb843a1cca3ac9077410ab2088f56a1811fbbb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "device_reduce_instance_multiblock_atomic_add.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -20,10 +24,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp index 41a60d5b70e26d34ecfae1234086fa7312be2599..9f2a89247501604f421f7a4c474acf5594520e4b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "device_reduce_instance_multiblock_atomic_add.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -20,10 +24,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp index bdcca274d7ffb6bb7c598ad1b3c908d3c3b1ea05..bd20069992e2d7c14bcf5348f3a03168316ee30f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F64_F64_F64_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F64_F64_F64_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "device_reduce_instance_multiblock_atomic_add.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -20,10 +24,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp index a2b4ae22bee116d2e5797eeece7efc61378280f3..6b84b25d0e27e105cb35edc50bab953fb04711aa 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp @@ -1,14 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_impl_common.hpp" -#include "device_reduce_threadwise.hpp" +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { #ifdef QUICK_REDUCE_TEST using reduce_configuration_2_instances_threadwise = std::tuple< @@ -47,10 +49,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple< >; #endif -template +template using deviceReduceThreadWisePtrType = DeviceReducePtr< - typename reduce_unary_operator::InElementwiseOperation, - typename reduce_unary_operator::AccElementwiseOperation>; + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; template void add_device_reduce_instance_threadwise( - std::vector>& device_op_instances) + std::vector>& device_op_instances) { - using ReduceOperation = typename reduce_binary_operator::opType; + using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; constexpr bool Indexable = (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || @@ -114,7 +115,7 @@ void add_device_reduce_instance_threadwise( ReduceOpId, \ PropagateNan, \ UseIndex>( \ - std::vector> & device_op_instances) + std::vector> & device_op_instances) #define ADD_THREADWISE_INST_BY_ID( \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ @@ -127,21 +128,17 @@ void add_device_reduce_instance_threadwise( Rank, \ NumReduceDim) -#define ADD_THREADWISE_INST_REF_BY_TYPE( \ - inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ - extern template void add_device_reduce_instance_threadwise( \ - std::vector::InElementwiseOperation, \ - typename reduce_unary_operator:: \ - AccElementwiseOperation>> & \ - device_op_instances) +#define ADD_THREADWISE_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_threadwise( \ + std::vector> & device_op_instances) #define ADD_THREADWISE_INST_REF_BY_ID( \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ @@ -154,10 +151,7 @@ void add_device_reduce_instance_threadwise( Rank, \ NumReduceDim) -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp index 0291f33214606771a0c2d75fd066cacbaa19cad7..5f7f5c7af5de44adfa9b266375d1226ab891f417 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp @@ -1,13 +1,16 @@ -#ifndef 
DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "data_type.hpp" -#include "device_reduce_instance_threadwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -50,10 +53,7 @@ ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp index 7ab1bebc5f7c76715bc3c587ef0272128def4363..3c21b408cce5050fee79939ee330c315792ddfb6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp @@ -1,13 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "data_type.hpp" -#include "device_reduce_instance_threadwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -37,10 +40,7 @@ ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp index 39c3d10660952059d82fa0aa95841db9587ae894..cd116986d998b446c07e12b5fbee7f7ff10afb04 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp @@ -1,13 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
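Editor's note: the per-type headers above and below are essentially tables of ADD_*_INST_REF_BY_ID invocations; as the threadwise header earlier in this patch shows, the REF_BY_* macros expand to extern template declarations, so each instantiation is compiled once inside the library instead of in every including translation unit. A self-contained sketch of that general extern-template pattern, with invented names (not CK's):

    #include <vector>

    // header: declare the template and promise that the float instantiation exists in the library
    template <typename T>
    void add_instances(std::vector<T>& v);
    extern template void add_instances<float>(std::vector<float>&); // explicit instantiation declaration

    // one library .cpp: define the template and emit the promised instantiation
    template <typename T>
    void add_instances(std::vector<T>& v)
    {
        v.push_back(T{});
    }
    template void add_instances<float>(std::vector<float>&); // explicit instantiation definition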
-#include "data_type.hpp" -#include "device_reduce_instance_threadwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -25,10 +28,7 @@ ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp index 3c47bfd1898d3bd76dc61bc1e29c4aa80d778d5c..a764735fa98a32d567de46feb61725192c2fa8f7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "device_reduce_instance_threadwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -48,10 +52,7 @@ ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp index 9df9f6f1faf773f143b012874639cd4aa556b426..7d47c79f8476b4979aa1ee53a1611901d8fd7470 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "device_reduce_instance_threadwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -24,10 +28,7 @@ ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1); ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp index 00ab218f2069409b9a46c072a29c4c856958d7e4..faced808a263e5abfc8102974dbb824f26a046ea 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "device_reduce_instance_threadwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -48,10 +52,7 @@ ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp index de7445b0437e7425aba01cb37fd78155483072cd..111ba7a0cf46fcf9f0b0abfb9327e7c2b9d747f8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "device_reduce_instance_threadwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -20,10 +24,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp index 1ea1ee745e7f477db5c949c49b81f204360149a7..c771f057d610ecbd77f087709ab2e7af74fad56b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp @@ -1,12 +1,16 @@ -#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP -#define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "device_reduce_instance_threadwise.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -36,10 +40,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck - -#endif diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 7cd6cc34c9d519019bf154787bed527474ac4406..0b82ba4357fc1973b76e4a24c90facfefbfbca71 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -1,195 +1,210 @@ -#ifndef CHECK_ERR_HPP -#define CHECK_ERR_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "data_type.hpp" - -namespace ck { -namespace utils { - -template -typename std::enable_if::value && !std::is_same::value, - bool>::type -check_err(const std::vector& out, - const std::vector& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-5, - double atol = 3e-6) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - bool res{true}; - int err_count = 0; - double err = 0; - double max_err = std::numeric_limits::min(); - for(std::size_t i = 0; i < ref.size(); ++i) - { - err = std::abs(out[i] - ref[i]); - if(err > atol + rtol * 
std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i])) - { - max_err = err > max_err ? err : max_err; - err_count++; - if(err_count < 5) - { - std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" - << i << "]: " << out[i] << " != " << ref[i] << std::endl - << msg << std::endl; - } - res = false; - } - } - if(!res) - { - std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; - } - return res; -} - -template -typename std::enable_if::value, bool>::type -check_err(const std::vector& out, - const std::vector& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - bool res{true}; - int err_count = 0; - double err = 0; - // TODO: This is a hack. We should have proper specialization for bhalf_t data type. - double max_err = std::numeric_limits::min(); - for(std::size_t i = 0; i < ref.size(); ++i) - { - double o = type_convert(out[i]); - double r = type_convert(ref[i]); - err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) - { - max_err = err > max_err ? err : max_err; - err_count++; - if(err_count < 5) - { - std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" - << i << "]: " << o << " != " << r << std::endl - << msg << std::endl; - } - res = false; - } - } - if(!res) - { - std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; - } - return res; -} - -template -typename std::enable_if::value || std::is_same::value, - bool>::type -check_err(const std::vector& out, - const std::vector& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - bool res{true}; - int err_count = 0; - double err = 0; - double max_err = std::numeric_limits::min(); - for(std::size_t i = 0; i < ref.size(); ++i) - { - double o = type_convert(out[i]); - double r = type_convert(ref[i]); - err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) - { - max_err = err > max_err ? 
err : max_err; - err_count++; - if(err_count < 5) - { - std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" - << i << "]: " << o << " != " << r << std::endl - << msg << std::endl; - } - res = false; - } - } - if(!res) - { - std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; - } - return res; -} - -template -typename std::enable_if::value && !std::is_same::value, bool>::type -check_err(const std::vector& out, - const std::vector& ref, - const std::string& msg = "Error: Incorrect results!", - double = 0, - double = 0) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - for(std::size_t i = 0; i < ref.size(); ++i) - { - if(out[i] != ref[i]) - { - std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) - << " != " << static_cast(ref[i]) << std::endl - << msg << std::endl; - return false; - } - } - return true; -} - -} // namespace utils -} // namespace ck - -template -std::ostream& operator<<(std::ostream& os, const std::vector& v) -{ - std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); - return os; -} - -#endif +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace utils { + +template +typename std::enable_if::value && !std::is_same::value, + bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-6) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + err = std::abs(out[i] - ref[i]); + if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i])) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << out[i] << " != " << ref[i] << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + // TODO: This is a hack. We should have proper specialization for bhalf_t data type. + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + double o = type_convert(out[i]); + double r = type_convert(ref[i]); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? 
err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << o << " != " << r << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + double o = type_convert(out[i]); + double r = type_convert(ref[i]); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << o << " != " << r << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value && !std::is_same::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double atol = 0) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + int64_t err = 0; + int64_t max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + int64_t o = out[i]; + int64_t r = ref[i]; + err = std::abs(o - r); + + if(err > atol) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) + << " != " << static_cast(ref[i]) << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << "max err: " << max_err << std::endl; + } + return res; +} + +} // namespace utils +} // namespace ck + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); + return os; +} diff --git a/library/include/ck/library/utility/conv_util.hpp b/library/include/ck/library/utility/conv_util.hpp index c881b8970564a7287bb9d50ab74c1d43fa177ad6..e57bde8adde50b1b76e803a6ac15d55a118502ab 100644 --- a/library/include/ck/library/utility/conv_util.hpp +++ b/library/include/ck/library/utility/conv_util.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
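Editor's note: a minimal usage sketch for the rewritten check_err overloads above; the vectors and tolerances below are made up for illustration. The floating-point overloads flag an element when |out - ref| > atol + rtol * |ref| or when either value is non-finite, and after this patch the integer overload compares the int64 difference against atol instead of requiring exact equality:

    #include <vector>
    #include "ck/library/utility/check_err.hpp"

    int main()
    {
        std::vector<float> out = {1.0f, 2.0f, 3.0f};
        std::vector<float> ref = {1.0f, 2.0f, 3.0f};
        bool ok = ck::utils::check_err(out, ref, "Error: Incorrect results!", 1e-5, 3e-6);

        std::vector<int> a = {1, 2, 3};
        std::vector<int> b = {1, 2, 4};
        // passes because the element-wise difference never exceeds atol = 1
        bool ok_int = ck::utils::check_err(a, b, "Error: Incorrect results!", 0.0, 1.0);

        return (ok && ok_int) ? 0 : 1;
    }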
+ #pragma once #include @@ -9,17 +12,17 @@ #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "device_conv_fwd.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "fill.hpp" -#include "host_tensor.hpp" -#include "op_instance_engine.hpp" -#include "reference_conv_fwd.hpp" -#include "tensor_layout.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/op_instance_engine.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" namespace ck { namespace tensor_operation { @@ -28,15 +31,15 @@ namespace device { using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr; -namespace device_conv1d_fwd_instance { +namespace instance { void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector&); void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector&); void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector&); void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector&); -} // namespace device_conv1d_fwd_instance -namespace device_conv2d_fwd_instance { +} // namespace instance +namespace instance { void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); @@ -45,15 +48,15 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); -} // namespace device_conv2d_fwd_instance -namespace device_conv3d_fwd_instance { +} // namespace instance +namespace instance { void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector&); void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector&); void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector&); void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector&); -} // namespace device_conv3d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation @@ -292,17 +295,17 @@ struct ConvolutionFwdInstances std::vector conv_ptrs; if constexpr(NumDimSpatial == 1) { - ck::tensor_operation::device::device_conv1d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); } else if constexpr(NumDimSpatial == 2) { - ck::tensor_operation::device::device_conv2d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); } else if constexpr(NumDimSpatial == 3) { - ck::tensor_operation::device::device_conv3d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); } return conv_ptrs; @@ -319,20 +322,20 @@ struct ConvolutionFwdInstances std::vector conv_ptrs; if constexpr(NumDimSpatial == 1) { - ck::tensor_operation::device::device_conv1d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); return conv_ptrs; } else if constexpr(NumDimSpatial == 
2) { - ck::tensor_operation::device::device_conv2d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); - ck::tensor_operation::device::device_conv2d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); } else if constexpr(NumDimSpatial == 3) { - ck::tensor_operation::device::device_conv3d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); } return conv_ptrs; @@ -349,17 +352,17 @@ struct ConvolutionFwdInstances std::vector conv_ptrs; if constexpr(NumDimSpatial == 1) { - ck::tensor_operation::device::device_conv1d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); } else if constexpr(NumDimSpatial == 2) { - ck::tensor_operation::device::device_conv2d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); } else if constexpr(NumDimSpatial == 3) { - ck::tensor_operation::device::device_conv3d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); } return conv_ptrs; @@ -376,17 +379,17 @@ struct ConvolutionFwdInstances std::vector conv_ptrs; if constexpr(NumDimSpatial == 1) { - ck::tensor_operation::device::device_conv1d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); } else if constexpr(NumDimSpatial == 2) { - ck::tensor_operation::device::device_conv2d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); } else if constexpr(NumDimSpatial == 3) { - ck::tensor_operation::device::device_conv3d_fwd_instance:: + ck::tensor_operation::device::instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); } return conv_ptrs; @@ -402,8 +405,8 @@ template , - typename WeightsInitFun = FillUniform> + typename InputInitFun = FillUniformDistribution, + typename WeightsInitFun = FillUniformDistribution> class ConvFwdOpInstance : public ck::utils::OpInstance { using DeviceConvFwdOp = tensor_operation::device:: @@ -422,8 +425,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance output_spatial_lengths_; const bool do_init_; - const InputInitFun& input_init_f_; - const WeightsInitFun& weights_init_f_; + InputInitFun input_init_f_; + WeightsInitFun weights_init_f_; }; } // namespace conv diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index f44aec969d3dd56cf5d26ec8c0e44078b1503482..6a76442779e416c2f206e573cdfb8ad112564a57 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -1,50 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include +#include #include -#include "data_type.hpp" +#include "ck/utility/data_type.hpp" namespace ck { namespace utils { -// template -// struct FillUniform; +template +struct FillUniformDistribution +{ + float a_{-5.f}; + float b_{5.f}; -// TODO: what's wrong with this specialization??? -// err: segmentation fault in mt19937 - infinite loop like. 
-// template -// struct FillUniform::value && -// !std::is_same::value>::type> -// { -// int a_{0}; -// int b_{5}; -// // T a_ = T{0}; -// // T b_ = T{5}; + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); + std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); + } +}; -// template -// void operator()(ForwardIter first, ForwardIter last) const -// { -// std::mt19937 gen{11939}; -// std::uniform_int_distribution dis(a_, b_); -// std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); -// } -// }; +// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below. +// However, this produces segfaults in std::mt19937 which look like an infinite loop. +// template +// struct FillUniformDistributionIntegerValue +// { +// int a_{-5}; +// int b_{5}; +// +// template +// void operator()(ForwardIter first, ForwardIter last) const +// { +// std::mt19937 gen(11939); +// std::uniform_int_distribution dis(a_, b_); +// std::generate( +// first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); +// } +// }; -// struct FillUniform::value || -// std::is_same::value>::type> +// Workaround for uniform_int_distribution not working as expected. See note above. template -struct FillUniform +struct FillUniformDistributionIntegerValue { - float a_{0}; - float b_{5}; + float a_{-5.f}; + float b_{5.f}; template void operator()(ForwardIter first, ForwardIter last) const { - std::mt19937 gen{11939}; - std::uniform_real_distribution<> dis(a_, b_); - std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); + std::generate( + first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); } }; diff --git a/library/include/ck/library/utility/op_instance_engine.hpp b/library/include/ck/library/utility/op_instance_engine.hpp index 5429f66d3ed7a382ea56263469bb98027619a286..8ba63f36e2e6a013bd1552046002ae749db4d7d4 100644 --- a/library/include/ck/library/utility/op_instance_engine.hpp +++ b/library/include/ck/library/utility/op_instance_engine.hpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
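Editor's note: a usage sketch for the renamed fill functors above, assuming they take the element type as a template parameter (the angle-bracket contents have been stripped from this copy of the patch, so that is an assumption):

    #include <vector>
    #include "ck/library/utility/fill.hpp"

    int main()
    {
        std::vector<float> buf(1024);

        // Uniform real values in the default range [-5, 5) shown above.
        ck::utils::FillUniformDistribution<float>{}(buf.begin(), buf.end());

        // Integer-valued fill that rounds a real distribution, working around the
        // uniform_int_distribution problem described in the comment above.
        ck::utils::FillUniformDistributionIntegerValue<float>{-3.f, 3.f}(buf.begin(), buf.end());

        return 0;
    }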
+ #pragma once #include +#include #include #include #include @@ -8,9 +12,12 @@ #include #include -#include "check_err.hpp" -#include "device_base.hpp" -#include "functional2.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" namespace ck { namespace utils { @@ -78,7 +85,8 @@ class OpInstanceRunEngine template > OpInstanceRunEngine(const OpInstanceT& op_instance, - const ReferenceOp& reference_op = ReferenceOp{}) + const ReferenceOp& reference_op = ReferenceOp{}, + bool do_verification = true) : op_instance_{op_instance} { in_tensors_ = op_instance_.GetInputTensors(); @@ -88,8 +96,11 @@ class OpInstanceRunEngine const Tensor&..., Tensor&>) { - ref_output_ = op_instance_.GetOutputTensor(); - CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); + if(do_verification) + { + ref_output_ = op_instance_.GetOutputTensor(); + CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); + } } AllocateDeviceInputTensors(std::make_index_sequence{}); out_device_buffer_ = @@ -110,6 +121,7 @@ class OpInstanceRunEngine op_ptr.get(), in_device_buffers_, out_device_buffer_); if(op_ptr->IsSupportedArgument(argument.get())) { + std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl; invoker->Run(argument.get()); out_device_buffer_->FromDevice(out_tensor_->mData.data()); if(!ref_output_) @@ -119,9 +131,16 @@ class OpInstanceRunEngine " You have to provide reference function."); } // TODO: enable flexible use of custom check_error functions - res = res && check_err(out_tensor_->mData, ref_output_->mData); + bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData); + std::cout << (inst_res ? 
"SUCCESS" : "FAILURE") << std::endl; + res = res && inst_res; out_device_buffer_->SetZero(); } + else + { + std::cout << "Given conv problem is not supported by instance: \n\t>>>>" + << op_ptr->GetTypeString() << std::endl; + } } return res; } @@ -132,7 +151,6 @@ class OpInstanceRunEngine bool do_verification = false, bool do_log = false) { - bool res{true}; ProfileBestConfig best_config; for(auto& op_ptr : op_ptrs) @@ -153,7 +171,7 @@ class OpInstanceRunEngine std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; - if(tflops < best_config.best_tflops) + if(avg_time < best_config.best_avg_time) { best_config.best_op_name = op_name; best_config.best_tflops = tflops; @@ -171,7 +189,7 @@ class OpInstanceRunEngine " You have to provide reference function."); } // TODO: enable flexible use of custom check_error functions - res = res && CheckErr(out_tensor_->mData, ref_output_->mData); + CheckErr(out_tensor_->mData, ref_output_->mData); if(do_log) {} } @@ -223,7 +241,7 @@ class OpInstanceRunEngine template bool CheckErr(const std::vector& dev_out, const std::vector& ref_out) const { - return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", atol_, rtol_); + return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_); } }; diff --git a/library/src/host_tensor/CMakeLists.txt b/library/src/host_tensor/CMakeLists.txt index 2a020b763dcb7b0ef482bc49ec04c86d16cadf73..eca22c6091fed2c937527a8090f65d84419b8cbe 100644 --- a/library/src/host_tensor/CMakeLists.txt +++ b/library/src/host_tensor/CMakeLists.txt @@ -1,12 +1,6 @@ ## host_tensor -include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/include/ck - ${PROJECT_SOURCE_DIR}/include/ck/utility - ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor -) - set(HOST_TENSOR_SOURCE - device.cpp + device_memory.cpp host_tensor.cpp ) @@ -17,22 +11,20 @@ target_compile_features(host_tensor PUBLIC) set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(host_tensor SYSTEM PUBLIC $) -target_include_directories(host_tensor PUBLIC +target_include_directories(host_tensor PUBLIC "$" "$" "$" ) -install(TARGETS host_tensor - EXPORT host_tensorTargets - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +rocm_install( + TARGETS host_tensor + EXPORT host_tensorTargets ) -install(EXPORT host_tensorTargets - FILE composable_kernelhost_tensorTargets.cmake +rocm_install( + EXPORT host_tensorTargets + FILE composable_kernelhost_tensorTargets.cmake NAMESPACE composable_kernel:: DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) diff --git a/library/src/host_tensor/device.cpp b/library/src/host_tensor/device.cpp deleted file mode 100644 index a50c8118fe397142b1b28c1883ae2f70323c8e12..0000000000000000000000000000000000000000 --- a/library/src/host_tensor/device.cpp +++ /dev/null @@ -1,137 +0,0 @@ -#include -#include -#include -#include -#include "device.hpp" - -#ifndef CK_NOGPU -DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) -{ - hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); -} - -void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } - -std::size_t DeviceMem::GetBufferSize() { return mMemSize; } - -void DeviceMem::ToDevice(const void* p) -{ - hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); -} - 
-void DeviceMem::FromDevice(void* p) -{ - hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); -} - -void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); } - -DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); } - -struct KernelTimerImpl -{ - KernelTimerImpl() - { - hip_check_error(hipEventCreate(&mStart)); - hip_check_error(hipEventCreate(&mEnd)); - } - - ~KernelTimerImpl() - { - hip_check_error(hipEventDestroy(mStart)); - hip_check_error(hipEventDestroy(mEnd)); - } - - void Start() - { - hip_check_error(hipDeviceSynchronize()); - hip_check_error(hipEventRecord(mStart, nullptr)); - } - - void End() - { - hip_check_error(hipEventRecord(mEnd, nullptr)); - hip_check_error(hipEventSynchronize(mEnd)); - } - - float GetElapsedTime() const - { - float time; - hip_check_error(hipEventElapsedTime(&time, mStart, mEnd)); - return time; - } - - hipEvent_t mStart, mEnd; -}; - -KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {} - -KernelTimer::~KernelTimer() {} - -void KernelTimer::Start() { impl->Start(); } - -void KernelTimer::End() { impl->End(); } - -float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); } -#endif - -DeviceAlignedMemCPU::DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment) - : mMemSize(mem_size), mAlignment(alignment) -{ - if(mem_size == 0) - { - mpDeviceBuf = nullptr; - } - else - { - assert(!(alignment == 0 || (alignment & (alignment - 1)))); // check pow of 2 - - // TODO: posix only - int rtn = posix_memalign(&mpDeviceBuf, alignment, mem_size); - - assert(rtn == 0); - } -} - -void* DeviceAlignedMemCPU::GetDeviceBuffer() { return mpDeviceBuf; } - -std::size_t DeviceAlignedMemCPU::GetBufferSize() { return mMemSize; } - -void DeviceAlignedMemCPU::ToDevice(const void* p) { memcpy(mpDeviceBuf, p, mMemSize); } - -void DeviceAlignedMemCPU::FromDevice(void* p) { memcpy(p, mpDeviceBuf, mMemSize); } - -void DeviceAlignedMemCPU::SetZero() { memset(mpDeviceBuf, 0, mMemSize); } - -DeviceAlignedMemCPU::~DeviceAlignedMemCPU() -{ - if(mpDeviceBuf != nullptr) - free(mpDeviceBuf); -} - -struct WallTimerImpl -{ - void Start() { mStart = std::chrono::high_resolution_clock::now(); } - - void End() { mStop = std::chrono::high_resolution_clock::now(); } - - float GetElapsedTime() const - { - return static_cast( - std::chrono::duration_cast(mStop - mStart).count()) * - 1e-3; - } - - std::chrono::time_point mStart; - std::chrono::time_point mStop; -}; - -WallTimer::WallTimer() : impl(new WallTimerImpl()) {} - -WallTimer::~WallTimer() {} - -void WallTimer::Start() { impl->Start(); } - -void WallTimer::End() { impl->End(); } - -float WallTimer::GetElapsedTime() const { return impl->GetElapsedTime(); } diff --git a/library/src/host_tensor/device_memory.cpp b/library/src/host_tensor/device_memory.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e7157e4e0fe53364b36c16093ffb574d17d1749 --- /dev/null +++ b/library/src/host_tensor/device_memory.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
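Editor's note: the DeviceMem implementation moves, essentially unchanged, from the deleted device.cpp into the new device_memory.cpp below; the CPU-memory helpers and timers that shared the old file do not reappear in this hunk. A minimal usage sketch based only on the members visible here:

    #include <vector>
    #include "ck/library/host_tensor/device_memory.hpp"

    int main()
    {
        std::vector<float> host(256, 1.0f);

        DeviceMem buf(host.size() * sizeof(float)); // hipMalloc in the constructor
        buf.ToDevice(host.data());                  // host -> device copy
        // ... launch a kernel that reads/writes buf.GetDeviceBuffer() ...
        buf.SetZero();                              // hipMemset over the whole allocation
        buf.FromDevice(host.data());                // device -> host copy

        return 0;                                   // ~DeviceMem releases the buffer with hipFree
    }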
+ +#include "ck/device_utility/hip_check_error.hpp" +#include "ck/library/host_tensor/device_memory.hpp" + +DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) +{ + hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); +} + +void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } + +std::size_t DeviceMem::GetBufferSize() { return mMemSize; } + +void DeviceMem::ToDevice(const void* p) +{ + hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); +} + +void DeviceMem::FromDevice(void* p) +{ + hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); +} + +void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); } + +DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); } diff --git a/library/src/host_tensor/host_tensor.cpp b/library/src/host_tensor/host_tensor.cpp index 138e3fc2549f2e0ce1c26290f2646a45d19fb855..dc9f5699dcb69782c2bcbab9ef58ba1a434fe1e3 100644 --- a/library/src/host_tensor/host_tensor.cpp +++ b/library/src/host_tensor/host_tensor.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include -#include "host_tensor.hpp" + +#include "ck/library/host_tensor/host_tensor.hpp" void HostTensorDescriptor::CalculateStrides() { @@ -50,25 +54,3 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) return os; } - -void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os) -{ - os << "dim " << desc.GetNumOfDimension() << ", "; - - os << "lengths {"; - LogRange(os, desc.GetLengths(), ", "); - os << "}, "; - - os << "strides {"; - LogRange(os, desc.GetStrides(), ", "); - os << "}" << std::endl; -} - -#if 1 -// FIXME: remove -void bf16_to_f32_(const Tensor& src, Tensor& dst) -{ - for(std::size_t i = 0; i < src.mData.size(); ++i) - dst.mData[i] = ck::type_convert(src.mData[i]); -} -#endif diff --git a/library/src/obselete_driver_offline/CMakeLists.txt b/library/src/obselete_driver_offline/CMakeLists.txt deleted file mode 100644 index 54b13953279e2173bc2f13c2dc568addfc74008e..0000000000000000000000000000000000000000 --- a/library/src/obselete_driver_offline/CMakeLists.txt +++ /dev/null @@ -1,37 +0,0 @@ -include_directories(BEFORE - include - ${PROJECT_SOURCE_DIR}/host/host_tensor/include - ${PROJECT_SOURCE_DIR}/host/device/include - ${PROJECT_SOURCE_DIR}/host/solver/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation - ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform - ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver - ${PROJECT_SOURCE_DIR}/external/rocm/include -) - -set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) -set(CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_fwd_driver_offline_nchwc.cpp) -set(CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_add_fwd_driver_offline_nchwc.cpp) -set(CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_maxpool_fwd_driver_offline_nchwc.cpp) -set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) -set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) -set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp) - -add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) -add_executable(conv_fwd_driver_offline_nchwc 
${CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE}) -add_executable(conv_add_fwd_driver_offline_nchwc ${CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE}) -add_executable(conv_maxpool_fwd_driver_offline_nchwc ${CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE}) -add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) -add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) -add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE}) - -target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) -target_link_libraries(conv_fwd_driver_offline_nchwc PRIVATE host_tensor) -target_link_libraries(conv_add_fwd_driver_offline_nchwc PRIVATE host_tensor) -target_link_libraries(conv_maxpool_fwd_driver_offline_nchwc PRIVATE host_tensor) -target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) -target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) -target_link_libraries(gemm_driver_offline PRIVATE host_tensor) diff --git a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp deleted file mode 100644 index a7541f03de80a4d90f8b1bdf34bf2e41ec2fe5ba..0000000000000000000000000000000000000000 --- a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp +++ /dev/null @@ -1,416 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "device_tensor.hpp" -#include "device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" - -#define USE_DYNAMIC_MODE 0 -#define USE_CONV_FWD_V5R1_NCHWC 1 - -enum ConvForwardAlgo -{ - V5R1NCHWC // 0 -}; - -template -void host_direct_convolution_add_nchwc(const Tensor& in, - const Tensor& wei, - const Tensor& add, - const Tensor& bias, - Tensor& add_host, - Tensor& out_host, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ck::ActivTypeEnum activ_type) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) { - double v = 0; - auto k = k0 * out_host.mDesc.GetLengths()[4] + k1; - - for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - - for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1) - { - v += static_cast(in(n, c0, hi, wi, c1)) * - static_cast(wei(k, c0, y, x, c1)); - } - } - } - } - } - - v += bias(k0, k1); - v = activ(v, activ_type); - - const int hox2 = ho * 2; - const int wox2 = wo * 2; - - out_host(n, k0, ho, wo, k1) = v; - - add_host(n, k0, hox2, wox2, k1) = v + add(n, k0, hox2, wox2, k1); - add_host(n, k0, hox2, wox2 + 1, k1) = v + add(n, k0, hox2, wox2 + 1, k1); - add_host(n, k0, hox2 + 1, wox2, k1) = v + add(n, k0, hox2 + 1, wox2, k1); - add_host(n, k0, hox2 + 1, wox2 + 1, k1) = v + add(n, k0, hox2 + 1, wox2 + 1, k1); - }; - - make_ParallelTensorFunctor(f_nchw, - 
out_host.mDesc.GetLengths()[0], - out_host.mDesc.GetLengths()[1], - out_host.mDesc.GetLengths()[2], - out_host.mDesc.GetLengths()[3], - out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); -} - -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - -#if USE_DYNAMIC_MODE - // dynamic mode - if(argc != 23) - { - printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(1); - } - - constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; - - const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); - const bool do_verification = std::stoi(argv[2]); - const int init_method = std::stoi(argv[3]); - const bool do_log = std::stoi(argv[4]); - const int nrepeat = std::stoi(argv[5]); - - const index_t N = std::stoi(argv[6]); - const index_t K0 = std::stoi(argv[7]); - const index_t K1 = std::stoi(argv[8]); - const index_t C0 = std::stoi(argv[9]); - const index_t C1 = std::stoi(argv[10]); - const index_t Y = std::stoi(argv[11]); - const index_t X = std::stoi(argv[12]); - const index_t Hi = std::stoi(argv[13]); - const index_t Wi = std::stoi(argv[14]); - - const index_t conv_stride_h = std::stoi(argv[15]); - const index_t conv_stride_w = std::stoi(argv[16]); - const index_t conv_dilation_h = std::stoi(argv[17]); - const index_t conv_dilation_w = std::stoi(argv[18]); - const index_t in_left_pad_h = std::stoi(argv[19]); - const index_t in_left_pad_w = std::stoi(argv[20]); - const index_t in_right_pad_h = std::stoi(argv[21]); - const index_t in_right_pad_w = std::stoi(argv[22]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - - const auto Hox2 = Ho * 2; - const auto Wox2 = Wo * 2; -#else - // static mode - if(argc < 6) - { - printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); - - const bool do_verification = std::stoi(argv[2]); - const int init_method = std::stoi(argv[3]); - const bool do_log = std::stoi(argv[4]); - const int nrepeat = std::stoi(argv[5]); - - constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; - -#if 0 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K1 = Number<8>{}; - constexpr auto K0 = Number<8>{}; -#elif 0 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<540>{}; - constexpr auto Wi = Number<960>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 0 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<270>{}; - constexpr auto Wi = Number<480>{}; - constexpr auto Y = Number<3>{}; - constexpr 
auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 1 - constexpr auto N = Number<128>{}; - constexpr auto Hi = Number<135>{}; - constexpr auto Wi = Number<240>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 1 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<32>{}; - constexpr auto Wi = Number<32>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K1 = Number<8>{}; - constexpr auto K0 = Number<8>{}; -#endif - - constexpr auto conv_stride_h = I1; - constexpr auto conv_stride_w = I1; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; - - constexpr auto Hox2 = Number{}; - constexpr auto Wox2 = Number{}; - -#endif - -#if 0 - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; -#elif 1 - using in_data_t = int8_t; - using acc_data_t = int32_t; - using out_data_t = int8_t; -#endif - - std::vector in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5), - add_lengths_host(5), bias_lengths_host(2); - - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C0); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - in_lengths_host[4] = static_cast(C1); - - wei_lengths_host[0] = static_cast(K0 * K1); - wei_lengths_host[1] = static_cast(C0); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - wei_lengths_host[4] = static_cast(C1); - - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K0); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - out_lengths_host[4] = static_cast(K1); - - add_lengths_host[0] = static_cast(N); - add_lengths_host[1] = static_cast(K0); - add_lengths_host[2] = static_cast(Hox2); - add_lengths_host[3] = static_cast(Wox2); - add_lengths_host[4] = static_cast(K1); - - bias_lengths_host[0] = static_cast(K0); - bias_lengths_host[1] = static_cast(K1); - - Tensor in(in_lengths_host); - Tensor wei(wei_lengths_host); - Tensor add(add_lengths_host); - Tensor add_device(add_lengths_host); - Tensor add_host(add_lengths_host); - Tensor bias(bias_lengths_host); - Tensor out_host(out_lengths_host); - - ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(add.mDesc, std::cout << "add: "); - - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, 
conv_dilation_w)); - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); - }; - wei.GenerateTensorValue(gen_wei, num_thread); - } - - bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - add.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - - auto f_make_for_device_nchwc = [&]() { - const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1); - const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1); - const auto add_lengths_dev = make_tuple(N, K0, Hox2, Wox2, K1); - const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - add_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - -#if USE_CONV_FWD_V5R1_NCHWC - if(algo == ConvForwardAlgo::V5R1NCHWC) - { - const auto tmp = f_make_for_device_nchwc(); - - device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( - tmp[I0], // in_lengths_dev - tmp[I1], // wei_lengths_dev - tmp[I2], // add_lengths_dev - tmp[I3], // out_lengths_dev - tmp[I4], // conv_strides_dev - tmp[I5], // conv_dilations_dev - tmp[I6], // in_left_pads_dev - tmp[I7], // in_right_pads_dev - in, - wei, - bias, - add, - add_device, - nrepeat); - } -#endif - - if(do_verification) - { - host_direct_convolution_add_nchwc(in, - wei, - add, - bias, - add_host, - out_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - activ_type); - - ck::utils::check_err(add_device.mData, add_host.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRangeAsType(std::cout << "add_host: ", add_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "add_device: ", add_device.mData, ",") << std::endl; - } - } -} diff --git a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp deleted file mode 100644 index c4dcb7c08539b034b9df056adb9a53f5573e0cd4..0000000000000000000000000000000000000000 --- a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp +++ 
/dev/null @@ -1,488 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "device_tensor.hpp" -#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp" - -#define USE_MODE 1 -#define USE_CONV_BWD_V4R1_XDL_NHWC 0 -#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 - -enum ConvTensorLayout -{ - NCHW, - NHWC, - CHWN, - NCHWc, - NHWCc -}; - -enum ConvBackwardDataAlgo -{ - V4R1XDLNHWC, // 0 - V4R1R2XDLNHWC, // 1 -}; - -template -void host_convolution_backward_data(Tensor& in, - const Tensor& wei, - const Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& /* in_right_pads */, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { - std::size_t K = wei.mDesc.GetLengths()[I0]; - std::size_t Y = wei.mDesc.GetLengths()[I2]; - std::size_t X = wei.mDesc.GetLengths()[I3]; - - std::size_t Ho = out.mDesc.GetLengths()[I2]; - std::size_t Wo = out.mDesc.GetLengths()[I3]; - - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - - if(h_tmp % conv_strides[I0] == 0) - { - int ho = h_tmp / conv_strides[I0]; - - if(ho >= 0 && ho < Ho) - { - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - - if(w_tmp % conv_strides[I1] == 0) - { - int wo = w_tmp / conv_strides[I1]; - - if(wo >= 0 && wo < Wo) - { - for(int k = 0; k < K; ++k) - { - v += out(n, k, ho, wo) * wei(k, c, y, x); - } - } - } - } - } - } - } - - in(n, c, hi, wi) = v; - }; - - auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { - std::size_t K = wei.mDesc.GetLengths()[I0]; - std::size_t Y = wei.mDesc.GetLengths()[I1]; - std::size_t X = wei.mDesc.GetLengths()[I2]; - - std::size_t Ho = out.mDesc.GetLengths()[I1]; - std::size_t Wo = out.mDesc.GetLengths()[I2]; - - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - - if(h_tmp % conv_strides[I0] == 0) - { - int ho = h_tmp / conv_strides[I0]; - - if(ho >= 0 && ho < Ho) - { - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - - if(w_tmp % conv_strides[I1] == 0) - { - int wo = w_tmp / conv_strides[I1]; - - if(wo >= 0 && wo < Wo) - { - for(int k = 0; k < K; ++k) - { - v += out(n, ho, wo, k) * wei(k, y, x, c); - } - } - } - } - } - } - } - - in(n, hi, wi, c) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_nchw, - in.mDesc.GetLengths()[0], - in.mDesc.GetLengths()[1], - in.mDesc.GetLengths()[2], - in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_nhwc, - in.mDesc.GetLengths()[0], - in.mDesc.GetLengths()[1], - in.mDesc.GetLengths()[2], - in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw 
std::runtime_error("wrong! not supported layout"); - } -} -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - -#if USE_MODE - // dynamic mode - if(argc != 22) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvBackwardDataAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - const index_t N = std::stoi(argv[7]); - const index_t K = std::stoi(argv[8]); - const index_t C = std::stoi(argv[9]); - const index_t Y = std::stoi(argv[10]); - const index_t X = std::stoi(argv[11]); - const index_t Hi = std::stoi(argv[12]); - const index_t Wi = std::stoi(argv[13]); - - const index_t conv_stride_h = std::stoi(argv[14]); - const index_t conv_stride_w = std::stoi(argv[15]); - const index_t conv_dilation_h = std::stoi(argv[16]); - const index_t conv_dilation_w = std::stoi(argv[17]); - const index_t in_left_pad_h = std::stoi(argv[18]); - const index_t in_left_pad_w = std::stoi(argv[19]); - const index_t in_right_pad_h = std::stoi(argv[20]); - const index_t in_right_pad_w = std::stoi(argv[21]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; -#else - // static mode - if(argc < 7) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvBackwardDataAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - constexpr auto N = Number<128>{}; - constexpr auto C = Number<192>{}; - constexpr auto Hi = Number<71>{}; - constexpr auto Wi = Number<71>{}; - constexpr auto K = Number<256>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - - constexpr auto conv_stride_h = I2; - constexpr auto conv_stride_w = I2; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; -#endif - -#if 0 - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; -#endif - - std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - - if(layout == 
ConvTensorLayout::NCHW) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(C); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - } - else if(layout == ConvTensorLayout::NHWC) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(Hi); - in_lengths_host[2] = static_cast(Wi); - in_lengths_host[3] = static_cast(C); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(Y); - wei_lengths_host[2] = static_cast(X); - wei_lengths_host[3] = static_cast(C); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(Ho); - out_lengths_host[2] = static_cast(Wo); - out_lengths_host[3] = static_cast(K); - } - else - { - throw std::runtime_error("wrong! not implemented"); - } - - Tensor in_host(in_lengths_host); - Tensor in_device(in_lengths_host); - Tensor wei(wei_lengths_host); - Tensor out(out_lengths_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(in_host.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - break; - default: - out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); - }; - wei.GenerateTensorValue(gen_wei, num_thread); - } - - auto f_make_for_device_nhwc = [&]() { -#if USE_MODE - const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); - const auto wei_lengths_dev = make_tuple(K, Y, X, C); - const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - -#if USE_CONV_BWD_V4R1_XDL_NHWC - if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); - } -#endif - -#if USE_CONV_BWD_V4R1R2_XDL_NHWC - if(algo == ConvBackwardDataAlgo::V4R1R2XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 && - in_right_pad_w == 0) - { - device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1< - in_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); - } - else - { -#if 1 - device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); -#endif - } - } -#endif - - if(do_verification) - { - host_convolution_backward_data(in_host, - wei, - out, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - ck::utils::check_err(in_device.mData, in_host.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRangeAsType(std::cout << "in_host : ", in_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "in_device: ", in_device.mData, ",") << std::endl; - } - } -} diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp deleted file mode 100644 index ab8beec87bfff4d0cbd1c8fb6e5a8549833d3171..0000000000000000000000000000000000000000 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp +++ /dev/null @@ -1,549 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "device_tensor.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" - -#define USE_DYNAMIC_MODE 1 -#define USE_CONV_FWD_V4R4_NCHW 0 -#define USE_CONV_FWD_V4R4R2_NHWC 0 -#define USE_CONV_FWD_V6R1_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 -#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 - -enum ConvTensorLayout -{ - NCHW, - NHWC, - CHWN, - NCHWc, - NHWCc -}; - -enum ConvForwardAlgo -{ - V4R4NCHW, // 0 - V4R4R2NHWC, // 1 - V6R1NCHW, // 2 - V4R4R2XDLNCHW, // 3 - V4R4R4XDLNHWC // 4 -}; - -template -void host_convolution_forward(const Tensor& in, - const Tensor& wei, - Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * 
conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - if constexpr(is_same::value) - { - v += ck::type_convert(in(n, c, hi, wi)) * - ck::type_convert(wei(k, c, y, x)); - } - else - { - v += static_cast(in(n, c, hi, wi)) * - static_cast(wei(k, c, y, x)); - } - } - } - } - } - - if constexpr(is_same::value) - { - out(n, k, ho, wo) = ck::type_convert(static_cast(v)); - } - else - { - out(n, k, ho, wo) = v; - } - }; - - auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && - wi < in.mDesc.GetLengths()[2]) - { - if constexpr(is_same::value) - { - v += ck::type_convert(in(n, hi, wi, c)) * - ck::type_convert(wei(k, y, x, c)); - } - else - { - v += static_cast(in(n, hi, wi, c)) * - static_cast(wei(k, y, x, c)); - } - } - } - } - } - if constexpr(is_same::value) - { - out(n, ho, wo, k) = ck::type_convert(static_cast(v)); - } - else - { - out(n, ho, wo, k) = v; - } - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_nhwc, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! 
not supported layout"); - } -} - -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - -#if USE_DYNAMIC_MODE - // dynamic mode - if(argc != 22) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvForwardAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - const index_t N = std::stoi(argv[7]); - const index_t K = std::stoi(argv[8]); - const index_t C = std::stoi(argv[9]); - const index_t Y = std::stoi(argv[10]); - const index_t X = std::stoi(argv[11]); - const index_t Hi = std::stoi(argv[12]); - const index_t Wi = std::stoi(argv[13]); - - const index_t conv_stride_h = std::stoi(argv[14]); - const index_t conv_stride_w = std::stoi(argv[15]); - const index_t conv_dilation_h = std::stoi(argv[16]); - const index_t conv_dilation_w = std::stoi(argv[17]); - const index_t in_left_pad_h = std::stoi(argv[18]); - const index_t in_left_pad_w = std::stoi(argv[19]); - const index_t in_right_pad_h = std::stoi(argv[20]); - const index_t in_right_pad_w = std::stoi(argv[21]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; -#else - // static mode - if(argc < 7) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvForwardAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - constexpr auto N = Number<128>{}; - constexpr auto C = Number<192>{}; - constexpr auto Hi = Number<71>{}; - constexpr auto Wi = Number<71>{}; - constexpr auto K = Number<256>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - - constexpr auto conv_stride_h = I1; - constexpr auto conv_stride_w = I1; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; -#endif - -#if 1 - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; -#elif 0 - using in_data_t = bhalf_t; - using acc_data_t = float; - using out_data_t = bhalf_t; -#elif 1 - using in_data_t = int8_t; - 
using acc_data_t = int32_t; - using out_data_t = int8_t; -#endif - - std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - - if(layout == ConvTensorLayout::NCHW) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(C); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - } - else if(layout == ConvTensorLayout::NHWC) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(Hi); - in_lengths_host[2] = static_cast(Wi); - in_lengths_host[3] = static_cast(C); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(Y); - wei_lengths_host[2] = static_cast(X); - wei_lengths_host[3] = static_cast(C); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(Ho); - out_lengths_host[2] = static_cast(Wo); - out_lengths_host[3] = static_cast(K); - } - else - { - std::runtime_error("wrong! not implemented"); - } - - Tensor in(in_lengths_host); - Tensor wei(wei_lengths_host); - Tensor out_host(out_lengths_host); - Tensor out_device(out_lengths_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); - }; - wei.GenerateTensorValue(gen_wei, num_thread); - } - - auto f_make_for_device_nchw = [&]() { - const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); - const auto wei_lengths_dev = make_tuple(K, C, Y, X); - const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - - auto f_make_for_device_nhwc = [&]() { - const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); - const auto wei_lengths_dev = make_tuple(K, Y, X, C); - const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - -#if USE_CONV_FWD_V4R4_NCHW - if(algo == ConvForwardAlgo::V4R4NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4R2_NHWC - if(algo == ConvForwardAlgo::V4R4R2NHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V6R1_NCHW - if(algo == ConvForwardAlgo::V6R1NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4R2_XDL_NCHW - if(algo == ConvForwardAlgo::V4R4R2XDLNCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4R4_XDL_NHWC - if(algo == ConvForwardAlgo::V4R4R4XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - - if(do_verification) - { - host_convolution_forward(in, - wei, - out_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - ck::utils::check_err(out_device.mData, out_host.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; - } - } -} diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp deleted file mode 100644 index 6fb8b4c2aa3b9c8251d340c8a448b7ca7a98a39c..0000000000000000000000000000000000000000 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp +++ /dev/null @@ -1,393 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "device_tensor.hpp" -#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" - -#define USE_DYNAMIC_MODE 0 -#define USE_CONV_FWD_V5R1_NCHWC 1 - -enum ConvForwardAlgo -{ - V5R1NCHWC // 0 -}; - -template -void host_direct_convolution_nchwc(const Tensor& in, - const Tensor& wei, - const Tensor& bias, - Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ck::ActivTypeEnum activ_type) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) { - double v = 0; - const int k = k0 * out.mDesc.GetLengths()[4] + k1; - - for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1) - { - v += static_cast(in(n, c0, hi, wi, c1)) * - static_cast(wei(k, c0, y, x, c1)); - } - } - } - } - } - v += bias(k0, k1); - out(n, k0, ho, wo, k1) = activ(v, activ_type); - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3], - out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); -} - -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - -#if 
USE_DYNAMIC_MODE - // dynamic mode - if(argc != 23) - { - printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(1); - } - - constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; - - const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); - const bool do_verification = std::stoi(argv[2]); - const int init_method = std::stoi(argv[3]); - const bool do_log = std::stoi(argv[4]); - const int nrepeat = std::stoi(argv[5]); - - const index_t N = std::stoi(argv[6]); - const index_t K0 = std::stoi(argv[7]); - const index_t K1 = std::stoi(argv[8]); - const index_t C0 = std::stoi(argv[9]); - const index_t C1 = std::stoi(argv[10]); - const index_t Y = std::stoi(argv[11]); - const index_t X = std::stoi(argv[12]); - const index_t Hi = std::stoi(argv[13]); - const index_t Wi = std::stoi(argv[14]); - - const index_t conv_stride_h = std::stoi(argv[15]); - const index_t conv_stride_w = std::stoi(argv[16]); - const index_t conv_dilation_h = std::stoi(argv[17]); - const index_t conv_dilation_w = std::stoi(argv[18]); - const index_t in_left_pad_h = std::stoi(argv[19]); - const index_t in_left_pad_w = std::stoi(argv[20]); - const index_t in_right_pad_h = std::stoi(argv[21]); - const index_t in_right_pad_w = std::stoi(argv[22]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; -#else - // static mode - if(argc < 6) - { - printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); - - const bool do_verification = std::stoi(argv[2]); - const int init_method = std::stoi(argv[3]); - const bool do_log = std::stoi(argv[4]); - const int nrepeat = std::stoi(argv[5]); - - // constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::Sigmoid; - constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; - -#if 0 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<1>{}; - constexpr auto K1 = Number<4>{}; -#elif 1 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 0 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<1>{}; - constexpr auto X = Number<1>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 0 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<540>{}; - constexpr auto Wi = Number<960>{}; - constexpr auto Y = Number<1>{}; - constexpr auto X = Number<1>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 0 - constexpr auto N = Number<128>{}; - constexpr 
auto Hi = Number<270>{}; - constexpr auto Wi = Number<480>{}; - constexpr auto Y = Number<1>{}; - constexpr auto X = Number<1>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#endif - - constexpr auto conv_stride_h = I1; - constexpr auto conv_stride_w = I1; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - -#if 1 - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; -#else - constexpr auto in_left_pad_h = I0; - constexpr auto in_left_pad_w = I0; - constexpr auto in_right_pad_h = I0; - constexpr auto in_right_pad_w = I0; -#endif - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; -#endif - -#if 0 - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; -#elif 1 - using in_data_t = int8_t; - using acc_data_t = int32_t; - using out_data_t = int8_t; -#endif - - std::vector in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5), - bias_lengths_host(2); - - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C0); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - in_lengths_host[4] = static_cast(C1); - - wei_lengths_host[0] = static_cast(K0 * K1); - wei_lengths_host[1] = static_cast(C0); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - wei_lengths_host[4] = static_cast(C1); - - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K0); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - out_lengths_host[4] = static_cast(K1); - - bias_lengths_host[0] = static_cast(K0); - bias_lengths_host[1] = static_cast(K1); - - Tensor in(in_lengths_host); - Tensor wei(wei_lengths_host); - Tensor bias(bias_lengths_host); - Tensor out_host(out_lengths_host); - Tensor out_device(out_lengths_host); - - ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(bias.mDesc, std::cout << "bias: "); - ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); - - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - 
bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - bias.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); - }; - wei.GenerateTensorValue(gen_wei, num_thread); - } - - auto f_make_for_device_nchwc = [&]() { - const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1); - const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1); - const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - -#if USE_CONV_FWD_V5R1_NCHWC - if(algo == ConvForwardAlgo::V5R1NCHWC) - { - const auto tmp = f_make_for_device_nchwc(); - - device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - bias, - out_device, - nrepeat); - } -#endif - - if(do_verification) - { - host_direct_convolution_nchwc(in, - wei, - bias, - out_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - activ_type); - - ck::utils::check_err(out_device.mData, out_host.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRangeAsType(std::cout << "bias: ", bias.mData, ",") << std::endl; - LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; - } - } -} diff --git a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp deleted file mode 100644 index fb7e8e975b9444a1fe058c6acf47c83ecf166f5e..0000000000000000000000000000000000000000 --- a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp +++ /dev/null @@ -1,415 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "device_tensor.hpp" -#include "device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" - -#define USE_DYNAMIC_MODE 0 -#define USE_CONV_FWD_V5R1_NCHWC 1 - -enum ConvForwardAlgo -{ - V5R1NCHWC // 0 -}; - -template -void host_direct_convolution_maxpool_nchwc(const Tensor& in, - const Tensor& wei, 
- const Tensor& bias, - Tensor& out_host, - Tensor& max_host, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ck::ActivTypeEnum activ_type) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) { - double v = 0; - auto k = k0 * out_host.mDesc.GetLengths()[4] + k1; - - for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1) - { - v += static_cast(in(n, c0, hi, wi, c1)) * - static_cast(wei(k, c0, y, x, c1)); - } - } - } - } - } - - v += bias(k0, k1); - v = activ(v, activ_type); - - out_host(n, k0, ho, wo, k1) = v; - }; - - make_ParallelTensorFunctor(f_nchw, - out_host.mDesc.GetLengths()[0], - out_host.mDesc.GetLengths()[1], - out_host.mDesc.GetLengths()[2], - out_host.mDesc.GetLengths()[3], - out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); - - auto maxpool_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) { - auto hx = ho * 2; - auto wx = wo * 2; - - auto v0 = out_host(n, k0, hx, wx, k1); - auto v1 = out_host(n, k0, hx, wx + 1, k1); - auto v2 = out_host(n, k0, hx + 1, wx, k1); - auto v3 = out_host(n, k0, hx + 1, wx + 1, k1); - - max_host(n, k0, ho, wo, k1) = std::max({v0, v1, v2, v3}); - }; - - make_ParallelTensorFunctor(maxpool_nchw, - max_host.mDesc.GetLengths()[0], - max_host.mDesc.GetLengths()[1], - max_host.mDesc.GetLengths()[2], - max_host.mDesc.GetLengths()[3], - max_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); -} - -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - -#if USE_DYNAMIC_MODE - // dynamic mode - if(argc != 23) - { - printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(1); - } - - constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; - - const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); - const bool do_verification = std::stoi(argv[2]); - const int init_method = std::stoi(argv[3]); - const bool do_log = std::stoi(argv[4]); - const int nrepeat = std::stoi(argv[5]); - - const index_t N = std::stoi(argv[6]); - const index_t K0 = std::stoi(argv[7]); - const index_t K1 = std::stoi(argv[8]); - const index_t C0 = std::stoi(argv[9]); - const index_t C1 = std::stoi(argv[10]); - const index_t Y = std::stoi(argv[11]); - const index_t X = std::stoi(argv[12]); - const index_t Hi = std::stoi(argv[13]); - const index_t Wi = std::stoi(argv[14]); - - const index_t conv_stride_h = std::stoi(argv[15]); - const index_t conv_stride_w = std::stoi(argv[16]); - const index_t conv_dilation_h = std::stoi(argv[17]); - const index_t conv_dilation_w = std::stoi(argv[18]); - const index_t 
in_left_pad_h = std::stoi(argv[19]); - const index_t in_left_pad_w = std::stoi(argv[20]); - const index_t in_right_pad_h = std::stoi(argv[21]); - const index_t in_right_pad_w = std::stoi(argv[22]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - - const index_t Ho_2 = Ho / 2; - const index_t Wo_2 = Wo / 2; -#else - // static mode - if(argc < 6) - { - printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); - - const bool do_verification = std::stoi(argv[2]); - const int init_method = std::stoi(argv[3]); - const bool do_log = std::stoi(argv[4]); - const int nrepeat = std::stoi(argv[5]); - - constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; - -#if 1 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 0 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<3>{}; - constexpr auto C1 = Number<4>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 0 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<540>{}; - constexpr auto Wi = Number<960>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#elif 0 - constexpr auto N = Number<128>{}; - constexpr auto Hi = Number<270>{}; - constexpr auto Wi = Number<480>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; -#endif - - constexpr auto conv_stride_h = I1; - constexpr auto conv_stride_w = I1; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; - - constexpr auto Ho_2 = Number{}; - constexpr auto Wo_2 = Number{}; - -#endif - -#if 0 - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; -#elif 1 - using in_data_t = int8_t; - using acc_data_t = int32_t; - using out_data_t = int8_t; -#endif - - std::vector in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5), - max_lengths_host(5), bias_lengths_host(2); - - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C0); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = 
static_cast(Wi); - in_lengths_host[4] = static_cast(C1); - - wei_lengths_host[0] = static_cast(K0 * K1); - wei_lengths_host[1] = static_cast(C0); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - wei_lengths_host[4] = static_cast(C1); - - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K0); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - out_lengths_host[4] = static_cast(K1); - - max_lengths_host[0] = static_cast(N); - max_lengths_host[1] = static_cast(K0); - max_lengths_host[2] = static_cast(Ho_2); - max_lengths_host[3] = static_cast(Wo_2); - max_lengths_host[4] = static_cast(K1); - - bias_lengths_host[0] = static_cast(K0); - bias_lengths_host[1] = static_cast(K1); - - Tensor in(in_lengths_host); - Tensor wei(wei_lengths_host); - Tensor bias(bias_lengths_host); - Tensor out_device(out_lengths_host); - Tensor out_host(out_lengths_host); - Tensor max_device(max_lengths_host); - Tensor max_host(max_lengths_host); - - ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); - - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); - }; - wei.GenerateTensorValue(gen_wei, num_thread); - } - - bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - - auto f_make_for_device_nchwc = [&]() { - const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1); - const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1); - const auto max_lengths_dev = make_tuple(N, K0, Ho_2, Wo_2, K1); - const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - max_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - -#if USE_CONV_FWD_V5R1_NCHWC - if(algo == ConvForwardAlgo::V5R1NCHWC) - { - const auto tmp = f_make_for_device_nchwc(); - - device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1< - in_data_t, - acc_data_t, - out_data_t, - activ_type>(tmp[I0], // in_lengths_dev - tmp[I1], // wei_lengths_dev - tmp[I2], // max_lengths_dev - tmp[I3], // out_lengths_dev - tmp[I4], // conv_strides_dev - tmp[I5], // conv_dilations_dev - tmp[I6], // in_left_pads_dev - tmp[I7], // in_right_pads_dev - in, - wei, - bias, - out_device, - max_device, - nrepeat); - } -#endif - - if(do_verification) - { - host_direct_convolution_maxpool_nchwc(in, - wei, - bias, - out_host, - max_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - activ_type); - - ck::utils::check_err(out_device.mData, out_host.mData); - ck::utils::check_err(max_device.mData, max_host.mData); - - if(do_log) - { - // LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; - // LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; - // LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << - // std::endl; - LogRangeAsType(std::cout << "max_host: ", max_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "max_device: ", max_device.mData, ",") << std::endl; - } - } -} diff --git a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp deleted file mode 100644 index 1ac974202cadda3b103f275d81c064c62ff19e37..0000000000000000000000000000000000000000 --- a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp +++ /dev/null @@ -1,532 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "device_tensor.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp" - -enum 
ConvTensorLayout -{ - NCHW, - NHWC, - CHWN, - NCHWc, - NHWCc -}; - -#define USE_DYNAMIC_MODE 1 -#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0 -#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0 -#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0 -#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0 -#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1 - -enum ConvBackwardWeightAlgo -{ - V4R4R2XDLNCHW, // 0 - V4R4R4XDLNHWC, // 1 - V4R4R2XDLATOMICNCHW, // 2 - V4R4R4XDLATOMICNHWC, // 3 - V4R4R5XDLATOMICNHWC, // 4 -}; - -template -void host_convolution_backward_weight(const Tensor& out, - const Tensor& in, - Tensor& wei, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - auto f_kcyx = [&](auto k, auto c, auto y, auto x) { - double v = 0; - for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += static_cast(in(n, c, hi, wi)) * - static_cast(out(n, k, ho, wo)); - } - } - } - } - wei(k, c, y, x) = v; - }; - - auto f_kyxc = [&](auto k, auto y, auto x, auto c) { - double v = 0; - for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && - wi < in.mDesc.GetLengths()[2]) - { - v += static_cast(in(n, hi, wi, c)) * - static_cast(out(n, ho, wo, k)); - } - } - } - } - wei(k, y, x, c) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_kcyx, - wei.mDesc.GetLengths()[0], - wei.mDesc.GetLengths()[1], - wei.mDesc.GetLengths()[2], - wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_kyxc, - wei.mDesc.GetLengths()[0], - wei.mDesc.GetLengths()[1], - wei.mDesc.GetLengths()[2], - wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! 
not supported layout"); - } -} - -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - -#if USE_DYNAMIC_MODE - // dynamic mode - if(argc != 23) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); - printf("additional: desired_grid_size\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - const index_t N = std::stoi(argv[7]); - const index_t K = std::stoi(argv[8]); - const index_t C = std::stoi(argv[9]); - const index_t Y = std::stoi(argv[10]); - const index_t X = std::stoi(argv[11]); - const index_t Hi = std::stoi(argv[12]); - const index_t Wi = std::stoi(argv[13]); - - const index_t conv_stride_h = std::stoi(argv[14]); - const index_t conv_stride_w = std::stoi(argv[15]); - const index_t conv_dilation_h = std::stoi(argv[16]); - const index_t conv_dilation_w = std::stoi(argv[17]); - const index_t in_left_pad_h = std::stoi(argv[18]); - const index_t in_left_pad_w = std::stoi(argv[19]); - const index_t in_right_pad_h = std::stoi(argv[20]); - const index_t in_right_pad_w = std::stoi(argv[21]); - - const index_t desired_grid_size = std::stoi(argv[22]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; -#else - // static mode - if(argc < 7) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - constexpr auto N = Number<128>{}; - constexpr auto C = Number<128>{}; - constexpr auto Hi = Number<14>{}; - constexpr auto Wi = Number<14>{}; - constexpr auto K = Number<256>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - - constexpr auto conv_stride_h = I1; - constexpr auto conv_stride_w = I1; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; -#endif - -#if 0 - using in_data_t = float; - using wei_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using out_data_t = half_t; - using acc_data_t 
= float; - using wei_data_t = float; -#elif 1 - using in_data_t = int8_t; - using out_data_t = int8_t; - using acc_data_t = int32_t; - using wei_data_t = int8_t; -#endif - - std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - - if(layout == ConvTensorLayout::NCHW) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(C); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - } - else if(layout == ConvTensorLayout::NHWC) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(Hi); - in_lengths_host[2] = static_cast(Wi); - in_lengths_host[3] = static_cast(C); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(Y); - wei_lengths_host[2] = static_cast(X); - wei_lengths_host[3] = static_cast(C); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(Ho); - out_lengths_host[2] = static_cast(Wo); - out_lengths_host[3] = static_cast(K); - } - else - { - std::runtime_error("wrong! not implemented"); - } - - Tensor in(in_lengths_host); - Tensor wei_device(wei_lengths_host); - Tensor wei_host(wei_lengths_host); - Tensor out(out_lengths_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei_host.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); - out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_out = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); - }; - out.GenerateTensorValue(gen_out, num_thread); - } - - auto f_make_for_device_nchw = [&]() { - const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); - const auto wei_lengths_dev = make_tuple(K, C, Y, X); - const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - - auto f_make_for_device_nhwc = [&]() { - const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); - const auto wei_lengths_dev = make_tuple(K, Y, X, C); - const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - - // set zero to wei_device - wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread); -#if USE_CONV_WRW_V4R4R2_XDL_NCHW - if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - nrepeat); - } -#endif - -#if USE_CONV_WRW_V4R4R4_XDL_NHWC - if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - nrepeat); - } -#endif - -#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW - if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw< - in_data_t, - wei_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - desired_grid_size, - nrepeat); - } -#endif - -#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC - if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk< - in_data_t, - wei_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - desired_grid_size, - nrepeat); - } -#endif - -#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC - if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk< - in_data_t, - wei_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - desired_grid_size, - nrepeat); - } -#endif - - if(do_verification) - { - host_convolution_backward_weight(out, - in, - wei_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - ck::utils::check_err(wei_device.mData, wei_host.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "out: ", out.mData, ",") << std::endl; - LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei_device: ", wei_device.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei_host : ", wei_host.mData, ",") << std::endl; - } - } -} diff --git a/library/src/obselete_driver_offline/gemm_driver_offline.cpp b/library/src/obselete_driver_offline/gemm_driver_offline.cpp deleted file mode 100644 index a09cb932d61d6e42004820114e0dca890c8f02e1..0000000000000000000000000000000000000000 --- a/library/src/obselete_driver_offline/gemm_driver_offline.cpp +++ /dev/null @@ -1,456 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdlops_mk_kn_mn.hpp" -#include "device_gemm_xdlops_mk_nk_mn.hpp" -#include "device_gemm_xdlops_km_kn_mn.hpp" -#include "device_gemm_xdlops_km_nk_mn.hpp" -#include "device_gemm_xdlops_mk_kn_nm.hpp" -#include "device_gemm_xdlops_mk_nk_nm.hpp" -#include "device_gemm_xdlops_km_kn_nm.hpp" -#include "device_gemm_xdlops_km_nk_nm.hpp" - -#define USE_GEMM_XDL_MK_KN_MN 1 -#define USE_GEMM_XDL_MK_NK_MN 1 -#define USE_GEMM_XDL_KM_KN_MN 1 -#define USE_GEMM_XDL_KM_NK_MN 1 -#define USE_GEMM_XDL_MK_KN_NM 0 -#define USE_GEMM_XDL_MK_NK_NM 0 -#define USE_GEMM_XDL_KM_KN_NM 0 -#define USE_GEMM_XDL_KM_NK_NM 0 - -enum struct GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM // 7 -}; - -enum struct GemmAlgo -{ - Xdl_MK_KN_MN, // 0 - Xdl_MK_NK_MN, // 1 - Xdl_KM_KN_MN, // 2 - Xdl_KM_NK_MN, // 3 - Xdl_MK_KN_NM, // 4 - Xdl_MK_NK_NM, // 5 - Xdl_KM_KN_NM, // 6 - Xdl_KM_NK_NM, // 7 -}; - -template -void host_gemm(const Tensor& a, - const Tensor& b, - Tensor& c, - const GemmMatrixLayout layout) -{ - if(layout == GemmMatrixLayout::MK_KN_MN) - { - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(k, n)); - } - - 
c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_NK_MN) - { - auto f_mk_nk_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(n, k)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_KN_MN) - { - auto f_km_kn_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(k, n)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_NK_MN) - { - auto f_km_nk_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(n, k)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_KN_NM) - { - auto f_mk_kn_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(k, n)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_NK_NM) - { - auto f_mk_nk_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(n, k)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_KN_NM) - { - auto f_km_kn_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(k, n)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_NK_NM) - { - auto f_km_nk_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(n, k)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! 
not supported layout"); - } -} -int main(int argc, char* argv[]) -{ - using namespace ck; - - if(argc != 12) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: M, N, K\n"); - printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n"); - exit(1); - } - - const auto layout = static_cast(std::stoi(argv[1])); - const auto algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - const index_t M = std::stoi(argv[7]); - const index_t N = std::stoi(argv[8]); - const index_t K = std::stoi(argv[9]); - - debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]); - debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]); - -#if 0 - using ab_data_t = float; - using acc_data_t = float; - using c_data_t = float; -#elif 1 - using ab_data_t = half_t; - using acc_data_t = float; - using c_data_t = half_t; -#elif 1 - using ab_data_t = int8_t; - using acc_data_t = int32_t; - using c_data_t = int8_t; -#endif - - std::vector a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); - std::vector a_strides_host(2), b_strides_host(2), c_strides_host(2); - - // A - if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN || - layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM) - { - a_lengths_host[0] = static_cast(M); - a_lengths_host[1] = static_cast(K); - a_strides_host[0] = static_cast(K); - a_strides_host[1] = static_cast(1); - } - else - { - a_lengths_host[0] = static_cast(K); - a_lengths_host[1] = static_cast(M); - a_strides_host[0] = static_cast(M); - a_strides_host[1] = static_cast(1); - } - - // B - if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN || - layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM) - { - b_lengths_host[0] = static_cast(N); - b_lengths_host[1] = static_cast(K); - b_strides_host[0] = static_cast(K); - b_strides_host[1] = static_cast(1); - } - else - { - b_lengths_host[0] = static_cast(K); - b_lengths_host[1] = static_cast(N); - b_strides_host[0] = static_cast(N); - b_strides_host[1] = static_cast(1); - } - - // C - if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN || - layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN) - { - c_lengths_host[0] = static_cast(M); - c_lengths_host[1] = static_cast(N); - c_strides_host[0] = static_cast(N); - c_strides_host[1] = static_cast(1); - } - else - { - c_lengths_host[0] = static_cast(N); - c_lengths_host[1] = static_cast(M); - c_strides_host[0] = static_cast(M); - c_strides_host[1] = static_cast(1); - } - - Tensor a(a_lengths_host, a_strides_host); - Tensor b(b_lengths_host, b_strides_host); - Tensor c_host(c_lengths_host, c_strides_host); - Tensor c_device(c_lengths_host, c_strides_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: "); - ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: "); - ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: "); - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - 
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - default: - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - } - -#if USE_GEMM_XDL_MK_KN_MN - if(algo == GemmAlgo::Xdl_MK_KN_MN) - { - if(layout != GemmMatrixLayout::MK_KN_MN) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_mk_kn_mn(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_MK_NK_MN - if(algo == GemmAlgo::Xdl_MK_NK_MN) - { - if(layout != GemmMatrixLayout::MK_NK_MN) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_mk_nk_mn(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_KM_KN_MN - if(algo == GemmAlgo::Xdl_KM_KN_MN) - { - if(layout != GemmMatrixLayout::KM_KN_MN) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_km_kn_mn(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_KM_NK_MN - if(algo == GemmAlgo::Xdl_KM_NK_MN) - { - if(layout != GemmMatrixLayout::KM_NK_MN) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_km_nk_mn(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_MK_KN_NM - if(algo == GemmAlgo::Xdl_MK_KN_NM) - { - if(layout != GemmMatrixLayout::MK_KN_NM) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_mk_kn_nm(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_MK_NK_NM - if(algo == GemmAlgo::Xdl_MK_NK_NM) - { - if(layout != GemmMatrixLayout::MK_NK_NM) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_mk_nk_nm(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_KM_KN_NM - if(algo == GemmAlgo::Xdl_KM_KN_NM) - { - if(layout != GemmMatrixLayout::KM_KN_NM) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_km_kn_nm(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_KM_NK_NM - if(algo == GemmAlgo::Xdl_KM_NK_NM) - { - if(layout != GemmMatrixLayout::KM_NK_NM) - { - throw std::runtime_error("wrong! 
layout"); - } - - device_gemm_xdlops_km_nk_nm(a, b, c_device, nrepeat); - } -#endif - - if(do_verification) - { - host_gemm(a, b, c_host, layout); - - ck::utils::check_err(c_device.mData, c_host.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host : ", c_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_device: ", c_device.mData, ",") << std::endl; - } - } -} diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index b20a4b57e584f7ddb1b3ab897bca2099fa5feef0..f1ce23aae2beff81fec17d2a1b42f1224f0c1960 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -1,80 +1,69 @@ -include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/include/ck - ${PROJECT_SOURCE_DIR}/include/ck/utility - ${PROJECT_SOURCE_DIR}/include/ck/host_utility - ${PROJECT_SOURCE_DIR}/include/ck/tensor_description - ${PROJECT_SOURCE_DIR}/include/ck/tensor - ${PROJECT_SOURCE_DIR}/include/ck/problem_transform - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element - ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor - ${PROJECT_SOURCE_DIR}/library/include/ck/library/host - ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance - ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce - ${PROJECT_SOURCE_DIR}/external/include/half -) - function(add_instance_library INSTANCE_NAME) message("adding instance ${INSTANCE_NAME}") - add_library(${INSTANCE_NAME} OBJECT ${ARGN}) + add_library(${INSTANCE_NAME} OBJECT ${ARGN}) target_compile_features(${INSTANCE_NAME} PUBLIC) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) endfunction(add_instance_library INSTANCE_NAME) add_subdirectory(gemm) -add_subdirectory(gemm_bias2d) -add_subdirectory(gemm_bias_relu) -add_subdirectory(gemm_bias_relu_add) +add_subdirectory(gemm_splitk) +add_subdirectory(gemm_bilinear) +add_subdirectory(gemm_add_add_fastgelu) add_subdirectory(gemm_reduce) +add_subdirectory(gemm_bias_add_reduce) add_subdirectory(batched_gemm) +add_subdirectory(batched_gemm_reduce) +add_subdirectory(grouped_gemm) +add_subdirectory(contraction_scale) +add_subdirectory(contraction_bilinear) add_subdirectory(conv1d_fwd) add_subdirectory(conv2d_fwd) add_subdirectory(conv3d_fwd) add_subdirectory(conv2d_fwd_bias_relu) add_subdirectory(conv2d_fwd_bias_relu_add) -add_subdirectory(conv2d_fwd_bias_relu_atomic_add) add_subdirectory(conv2d_bwd_data) -add_subdirectory(reduce) add_subdirectory(convnd_bwd_data) -add_subdirectory(grouped_gemm) add_subdirectory(conv2d_bwd_weight) -add_subdirectory(batched_gemm_reduce) +add_subdirectory(convnd_bwd_weight) +add_subdirectory(reduce) +add_subdirectory(normalization) +add_subdirectory(elementwise) -add_library(device_operations STATIC - $ - $ - $ - $ - $ - $ - $ +add_library(device_operations STATIC $ - $ - $ - $ - $ - $ - $ - $ + $ + $ + $ + $ + $ $ + $ + $ + $ + $ + $ $ - device_conv2d.cpp + $ + $ + $ + $ + $ + $ + $ + $ + $ ) 
add_library(composablekernels::device_operations ALIAS device_operations) -set(DEV_OPS_INC_DIRS +set(DEV_OPS_INC_DIRS ${PROJECT_SOURCE_DIR}/include/ck/ ${PROJECT_SOURCE_DIR}/library/include/ck/ - ${PROJECT_SOURCE_DIR}/external/include/ ) + target_compile_features(device_operations PUBLIC) set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_include_directories(device_operations PUBLIC +target_include_directories(device_operations PUBLIC $ $ $ @@ -87,29 +76,25 @@ target_include_directories(device_operations PUBLIC $ $ $ - $ $ + $ $ - $ ) #once new arches are enabled make this an option on the main cmake file # and pass down here to be exported - -target_compile_options(device_operations -PRIVATE --offload-arch=gfx908 +target_compile_options(device_operations PRIVATE + --offload-arch=gfx908 + --offload-arch=gfx90a ) + # install(TARGETS device_operations LIBRARY DESTINATION lib) -install(TARGETS device_operations - EXPORT device_operationsTargets - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} -) -install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) -install(EXPORT device_operationsTargets - FILE composable_kerneldevice_operationsTargets.cmake +rocm_install(TARGETS device_operations + EXPORT device_operationsTargets) + +rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) +rocm_install(EXPORT device_operationsTargets + FILE composable_kerneldevice_operationsTargets.cmake NAMESPACE composable_kernel:: DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp index 9641e3cf72d76ee52c69b832b47bfca8df879693..1cc92524c6bd9a9eb336a63bcfe3f98dcab4852f 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -23,29 +29,31 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple< // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, - DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, - DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, - DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, - DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, - DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, - DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, - DeviceBatchedGemmXdl< BF16, BF16, BF16, 
F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> // clang-format on >; void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{}); } -} // namespace device_batched_gemm_instance 
+} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp index c93c77dccced36ab67f458a115abaa2d24e3b878..c35a8d6d66bef8bdcf36df3a0fecfb451e63c424 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -39,13 +44,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp index 8da334071a6cdf17248adde56a723e7961c8893c..1bbedebeb81f8fa9015313eae7272ed93edd5411 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
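Each of these instance translation units follows the same shape: a std::tuple enumerates fully specified DeviceBatchedGemmXdl configurations (data types, layouts, element-wise operations, then the tile sizes and thread-cluster/vector-transfer parameters spelled out by the column key above), and add_device_operation_instances turns that compile-time list into run-time objects by appending one default-constructed instance per tuple element to the caller's vector of unique_ptrs to the operation interface. Below is a minimal, self-contained sketch of that idiom with made-up kernel types; it illustrates the mechanism only and is not the library's actual helper implementation.

#include <memory>
#include <tuple>
#include <vector>

// Stand-ins for the device-op interface and two concrete configurations.
struct BaseOp
{
    virtual ~BaseOp() = default;
};
struct KernelConfigA : BaseOp
{
};
struct KernelConfigB : BaseOp
{
};

// One way such a registration helper can be written (C++17 fold expression):
// default-construct every type listed in the tuple and hand ownership to the
// caller's vector of base-class pointers.
template <typename BasePtrVec, typename... Ops>
void register_instances(BasePtrVec& instances, std::tuple<Ops...>)
{
    (instances.push_back(std::make_unique<Ops>()), ...);
}

int main()
{
    std::vector<std::unique_ptr<BaseOp>> instances;
    register_instances(instances, std::tuple<KernelConfigA, KernelConfigB>{});
    return instances.size() == 2 ? 0 : 1;
}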
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -43,13 +48,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp index 9566d5ecd4cac8dd05de288785706adbf139b703..2ceaa20b80b9a9b5fb20b7534c8ee38f8f95e4b2 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -44,13 +49,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp index 3be808371340b6666444f2e2811b10a6d26b951d..3696285726a1736b5d1b75f3035df46fe5fb8580 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -39,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp index 21daf0b1931b8c5950f49e0d2bca83054f337b9e..f79d304187db7149e911893831caabb00d17bdb5 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -39,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp index 9606b1f0cc72170b5ad2bbbd4597bbb1fed1a377..8290e7565ccd3d5eab0100146fe0fafc79ec47fe 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -48,13 +53,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp index 3d3e35e8e453a7c863303ea32eeb397ff6892316..f3345eba81e4fc2b2ba938013c49273fd4ab58ff 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -44,13 +49,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp index c6d6a1ba6a37c89cfc6dfa8dbca89dfe34818169..8b671dfdb4fa7567e85486178c7d44a0cd4b60ba 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -39,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp index 157bf413ac32d90e6236fde403f0858adaf82bc9..646450e722de0cb723c9ec8247074a2234220d90 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -39,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp index 5a8988722e216ad5bf06196c3f04981525b93447..1696d29713f42ac7bb54587ddfcc6bcc76f9b6e9 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -39,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp index 2e892d97f51033d96ba1f0928090159ce82f18b4..3dbd63707d667087f4b7f5b175b07dd0c76f1cf5 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -44,13 +49,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp index 1f3951c938f201725a54b91cf5ea0d266b81ad75..0691f4f865b751d5293feb4b53dee93c012ad363 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -54,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp index d6faa5a9cb3113ae1fa24a346b42d1da167b9ce5..efd49bf12de2f6ec88f1b21911a1f5f5614e0f92 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -54,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp index b5bc2786f234f718e2253e46dc0631d444c23951..9c3d6609ca7221824c43122f8b018c2e4651a573 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -54,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp index 6858903ff48dca1aedb74880bb85c9fc7b8ee981..330d1396072d9066ecec5da719bfb013f1fe5a88 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_batched_gemm_instance { +namespace instance { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -46,13 +51,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{}); } -} // namespace device_batched_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp index a15b5b7351792f9a8b8509ff4af070ff448aceec..1c4541afc5f375e76e3e6adf6c24fb543a36f477 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -1,18 +1,23 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { -using F16 = ck::half_t; -using F32 = float; -using DPtrsGlobal = ck::Tuple; +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,13 +26,13 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; +using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -38,43 +43,38 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 
256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 
32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + 
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 
2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances( instances, device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp index a53cb8fc70bdca1c21884282c4387124aecfa4e6..07eb9b943c33d1bafe587b8eab741c95b70150fb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -1,18 +1,23 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { -using F16 = ck::half_t; -using F32 = float; -using DPtrsGlobal = ck::Tuple; +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,13 +26,13 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; +using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -38,43 +43,38 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, 
GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 
256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, 
F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances( instances, device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp index ce929502cd8ca5a6a747d221d1b92eee55861a2a..2d9cee47d4b1663503d65a0bcee6d29491f41747 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -1,18 +1,23 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { -using F16 = ck::half_t; -using F32 = float; -using DPtrsGlobal = ck::Tuple; +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,13 +26,13 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; +using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -38,43 +43,38 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 
256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 
8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< 
Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 
1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances( instances, device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp index c709aa411c9e0d6e1cffdab17764b6422877337f..03ce1ce08b15e7aec81c3eef5d3f4637585b542c 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -1,18 +1,23 @@ -#include -#include "config.hpp" -#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { -using F16 = ck::half_t; -using F32 = float; -using DPtrsGlobal = ck::Tuple; +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,13 +26,13 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; +using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -38,40 +43,35 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 
2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| 
DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> // clang-format on >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances( instances, device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fb38c645ebaec5c4d7ffd6111e479c4d602df09c --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_contraction_bilinear_instance +set(DEVICE_CONTRACTION_BILINEAR_INSTANCE_SOURCE + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp +) + +add_library(device_contraction_bilinear_instance OBJECT ${DEVICE_CONTRACTION_BILINEAR_INSTANCE_SOURCE}) +set_target_properties(device_contraction_bilinear_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_contraction_bilinear_instance) diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..036818ee2cc15bdf94998be1c54f4c9da456474d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_TUPLE = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| 
Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b277fb86e8d586f6bd5542182e1fd30a8cb35c65 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
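For orientation, every instance in these contraction_bilinear files computes E[m0, m1, n0, n1] = alpha * sum over (k0, k1) of A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + beta * D[m0, m1, n0, n1]. Below is a minimal host-side reference sketch, assuming the Bilinear elementwise operation combines as alpha * c + beta * d and assuming plain row-major storage; the function name and parameter names are hypothetical, not library API.

#include <cstddef>
#include <vector>

// Naive reference: E = alpha * (A contracted with B over k0, k1) + beta * D.
// Each tensor is assumed contiguous with its last listed index fastest-changing.
void reference_contraction_bilinear(const std::vector<float>& A, // [M0][M1][K0][K1]
                                    const std::vector<float>& B, // [N0][N1][K0][K1]
                                    const std::vector<float>& D, // [M0][M1][N0][N1]
                                    std::vector<float>& E,       // [M0][M1][N0][N1]
                                    std::size_t M0, std::size_t M1,
                                    std::size_t N0, std::size_t N1,
                                    std::size_t K0, std::size_t K1,
                                    float alpha, float beta)
{
    for(std::size_t m0 = 0; m0 < M0; ++m0)
        for(std::size_t m1 = 0; m1 < M1; ++m1)
            for(std::size_t n0 = 0; n0 < N0; ++n0)
                for(std::size_t n1 = 0; n1 < N1; ++n1)
                {
                    float acc = 0.f;
                    for(std::size_t k0 = 0; k0 < K0; ++k0)
                        for(std::size_t k1 = 0; k1 < K1; ++k1)
                            acc += A[((m0 * M1 + m1) * K0 + k0) * K1 + k1] *
                                   B[((n0 * N1 + n1) * K0 + k0) * K1 + k1];

                    const std::size_t e = ((m0 * M1 + m1) * N0 + n0) * N1 + n1;
                    E[e] = alpha * acc + beta * D[e]; // Bilinear combine
                }
}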
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_TUPLE = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 
16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c03ce0b169edbfde8749602231831a67a1573d5c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_TUPLE = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| 
Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab56c4c15988f41917281f364a188803dfda29b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
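The four bilinear variants (kknn, knnn, mknn, mnnn) differ in which index of each operand is the fast changing (contiguous) one, as the per-file comments note. The helpers below are a hedged illustration of what that means for addressing the A operand; the 4D extent ordering is an assumption made only for this sketch, not taken from the library.

#include <cstddef>

// "k" fast for A (kknn/knnn variants): k1 is contiguous, e.g. A stored as [M0][M1][K0][K1].
constexpr std::size_t offset_A_k_fast(std::size_t m0, std::size_t m1,
                                      std::size_t k0, std::size_t k1,
                                      std::size_t M1, std::size_t K0, std::size_t K1)
{
    return ((m0 * M1 + m1) * K0 + k0) * K1 + k1;
}

// "m" fast for A (mknn/mnnn variants): m1 is contiguous, e.g. A stored as [K0][K1][M0][M1].
constexpr std::size_t offset_A_m_fast(std::size_t m0, std::size_t m1,
                                      std::size_t k0, std::size_t k1,
                                      std::size_t M0, std::size_t M1, std::size_t K1)
{
    return ((k0 * K1 + k1) * M0 + m0) * M1 + m1;
}

// B is addressed analogously (k fast vs. n fast); D and E keep n fast in all four variants.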
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_TUPLE = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 
16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, 
S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..32806757a525a4fc488bd808d0f31eb474895ce3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_contraction_scale_instance +set(DEVICE_CONTRACTION_SCALE_INSTANCE_SOURCE + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp +) + +add_library(device_contraction_scale_instance OBJECT ${DEVICE_CONTRACTION_SCALE_INSTANCE_SOURCE}) +set_target_properties(device_contraction_scale_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_contraction_scale_instance) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f49a98642f6454994079d3969024f9a7a052bb4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
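A note on the two epilogues used in these contraction instance files: the Bilinear instances above carry one extra D tensor (the F32_TUPLE Ds type) that is blended with the contraction result, while the Scale instances in this and the following files carry no D tensors (EMPTY_TUPLE) and only rescale the result. A minimal sketch of the intended semantics, not the library's actual ck::tensor_operation::element_wise definitions:

    // Hedged sketch of the CDE epilogue semantics assumed by these instances.
    struct ScaleSketch
    {
        float alpha;
        // e = alpha * c, where c is the accumulated A-B contraction result
        void operator()(float& e, const float& c) const { e = alpha * c; }
    };

    struct BilinearSketch
    {
        float alpha;
        float beta;
        // e = alpha * c + beta * d, where d comes from the single Ds tensor
        void operator()(float& e, const float& c, const float& d) const
        {
            e = alpha * c + beta * d;
        }
    };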
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using EMPTY_TUPLE = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45ffa63ce28b9af947eee1771f010ac848dc1fca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using EMPTY_TUPLE = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 
128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc63b06a56e20bf6b676a30b3d6cb0f68b815321 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using EMPTY_TUPLE = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 
1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ce11f255a62072e816571fe348376b600ae58fcf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using EMPTY_TUPLE = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, 
GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp index 9288e40e566a9eadbcdb691f58908bb28d56e809..d4c65ff54b092aafdd650d8e18351d1593a1edb1 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
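Taken together, the four contraction_scale translation units (kkn, knn, mkn, mnn) follow the same pattern: a std::tuple of DeviceContractionMultipleD_Xdl_CShuffle configurations plus an add_device_contraction_scale_..._instance function that forwards the tuple to add_device_operation_instances. The trailing letters of each file name record the fast-changing dimension of A, B and E respectively, and the operation computed is E[m0, m1, n0, n1] = sum over k0, k1 of A[m0, m1, k0, k1] * B[n0, n1, k0, k1], with Scale applying its alpha factor to the result. Below is a hedged sketch of how a populated instance vector is typically inspected; DeviceOpPtr stands in for the unique_ptr element type elided in the signatures above, and GetTypeString() is assumed from ck's BaseOperator interface:

    #include <iostream>
    #include <vector>

    // Hedged sketch: enumerate registered instances and print their names.
    template <typename DeviceOpPtr>
    void list_instances(const std::vector<DeviceOpPtr>& instances)
    {
        for(const auto& op : instances)
        {
            // Each device op reports a human-readable description of its
            // tuning parameters, which profilers use to identify the winner.
            std::cout << op->GetTypeString() << std::endl;
        }
    }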
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv1d_fwd_instance { +namespace instance { using F32 = float; using BF16 = bhalf_t; @@ -28,15 +34,12 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple< -// clang-format off + // clang-format off //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if !CK_WORKAROUND_GITHUB_135 - // FIXME: this instance causes numerical errors. 
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, -#endif DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, @@ -106,7 +109,7 @@ void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances( device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); } -} // namespace device_conv1d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp index 669dca617a053712ff65961930fafef5dc09cc30..166d25ba488bc09169f7beca28cd0b231d62305a 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
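The conv1d/conv2d instance hunks in the rest of this diff all make the same two changes: includes move under the ck/ prefix (with add_device_operation_instance.hpp replacing device_operation_instance.hpp), and the per-operation namespaces such as device_conv1d_fwd_instance collapse into the single instance namespace. For callers only the qualification changes; a hedged sketch using the f16 conv1d entry point, assuming the matching declaration header is included and with InstanceVector standing in for the elided vector-of-device-op-pointers type:

    // Hedged sketch of the caller-side effect of the namespace rename.
    template <typename InstanceVector>
    void populate_conv1d_fwd_f16(InstanceVector& instances)
    {
        // Previously: ck::tensor_operation::device::device_conv1d_fwd_instance::...
        ck::tensor_operation::device::instance::
            add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(instances);
    }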
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv1d_fwd_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -103,7 +109,7 @@ void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances( device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); } -} // namespace device_conv1d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp index 0abd47142ba2c2885fc14dcb5ad0ddade5422353..2cb296e47200e0382872bee0d8e3c810ffdd4c84 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv1d_fwd_instance { +namespace instance { using F32 = float; @@ -106,7 +112,7 @@ void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances( device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); } -} // namespace device_conv1d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp index 53e0f7755022293d098ea67f9add1f7f80ba8a8f..2364c5ea327a56c29ea4747cb92a7e0fade68548 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv1d_fwd_instance { +namespace instance { using F32 = float; @@ -105,7 +111,7 @@ void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances( device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{}); } -} // namespace device_conv1d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index b5814aa17fc4c6865c46ca47837d0b7552ac1f69..3b716d641c728b46be734dd2509fa4100241cf8e 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -77,7 +82,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index 53498aff344f7ef149ca4781f88d3f1bb4914d96..5978ffcd10b7ff723d8aebfe8adf63910ea7b7cd 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -79,7 +84,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index fbe279e0333488312b621566df803ecb3cbfb867..42e80be1a0cd2c98367bdcb21fe94932ba6840a8 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using F32 = float; @@ -76,7 +81,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index 7fd51bbfbfb2edf8c3b5d27a86456764b1d47491..ff15c0238b3b4543d67cc3df1cac1934429e52c0 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using DataType = int8_t; using AccType = int32_t; @@ -77,7 +82,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt index 7c384a882b7b4bc650ccad523687f0d648671093..7d3c57b235e9314083917ab100db82080643239e 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt @@ -3,9 +3,9 @@ set(DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; ) -add_library(device_conv2d_bwd_weight_instance OBJECT ${DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE}) +add_library(device_conv2d_bwd_weight_instance OBJECT ${DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE}) target_compile_features(device_conv2d_bwd_weight_instance PUBLIC) set_target_properties(device_conv2d_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_bwd_weight_instance LIBRARY DESTINATION lib) +rocm_install(TARGETS device_conv2d_bwd_weight_instance) clang_tidy_check(device_conv2d_bwd_weight_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index d915db67587fa9b48a98a64fa022f12e9c8d2f12..ea9fb8c6a8bdc13ee247342694a088376b63ff47 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_weight_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -47,7 +52,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances{}); } -} // namespace device_conv2d_bwd_weight_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index e9f6636518d91befced1241d5b01aab6db545ae6..744f2f91e8b8b06d3db660b102e7d6ef3dd62618 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_weight_instance { +namespace instance { using F32 = float; @@ -46,7 +51,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances{}); } -} // namespace device_conv2d_bwd_weight_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index 857e36d6f5783973563e8ab9211fff15c9bd2513..1ef4a9b07e15d1c9e4f14b87f5656b911285f654 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -6,7 +6,18 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; ) +set(DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; +) + add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +add_library(device_convnd_2d_fwd_instance 
OBJECT ${DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE}) + set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_convnd_2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(device_conv2d_fwd_instance) +clang_tidy_check(device_convnd_2d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp index b2f6f9335eb9cac9c91203c3f696f8544fdb8216..7766a12eb9d39a46a2dd57aad38b2f38159eb30a 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_fwd_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -138,7 +143,7 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances{}); } -} // namespace device_conv2d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index 47405ea1bfb21ac984ad0d235fea80ae15ee056f..efb4bd875fcfc85b09f5642fa6e740b1fec3afab 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_fwd_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -104,7 +109,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); } -} // namespace device_conv2d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index a4060f8bf207c8fc34fa90c8614bc2dd7053e214..5c0110aa5105b67a1832328a702328f6bbc0fcd5 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_fwd_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -103,7 +108,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); } -} // namespace device_conv2d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index 3c46c2f7e98bb203072b3fc7fa9a93812d2d10b6..3e4c8debc906857d504663391f84a2a212979fc3 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_fwd_instance { +namespace instance { using F32 = float; @@ -102,7 +107,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); } -} // namespace device_conv2d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index 0db59ca394c8f30d8ddd25c38fd684309b806a4b..cd1bf085fb67564a8b2e9e2c8283665bb6b73f1c 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_fwd_instance { +namespace instance { using F32 = float; @@ -103,7 +108,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); } -} // namespace device_conv2d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75351654baee07c27327bbc5b0fc64511a513492 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
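Editorial note on the new device_convnd_2d_fwd_* files that follow: each tuning configuration is instantiated for three ConvolutionForwardSpecialization values (Default, Filter1x1Pad0, Filter1x1Stride1Pad0), and each file repeats the layout equation in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]. The sketch below is ordinary convolution output-size arithmetic, not code taken from this repository; it shows why a 1x1 filter with stride 1 and zero padding leaves the spatial extent unchanged, which is the case the Filter1x1Stride1Pad0 specialization targets.

// Standard convolution output-size arithmetic (illustration only), assuming
// symmetric padding: out = (in + 2*pad - dilation*(filter - 1) - 1) / stride + 1.
constexpr int conv_out_size(int in, int filter, int stride, int pad, int dilation)
{
    return (in + 2 * pad - dilation * (filter - 1) - 1) / stride + 1;
}

int main()
{
    static_assert(conv_out_size(71, 3, 2, 1, 1) == 36, "general 3x3, stride 2, pad 1");
    static_assert(conv_out_size(71, 1, 1, 0, 1) == 71, "1x1, stride 1, pad 0: Ho == Hi");
    return 0;
}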
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + 
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, 
BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c274e7e49d9ee39e49c4bb4dd9853011a5d32073 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + 
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + 
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, 
PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..22cb76641537506a41cc2f43eaaef7cf98dd13c7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| 
+ //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, 
PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..076faf7f3b79501777fe0e1da140817a9e345bc7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, 
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 
1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& 
instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances{});
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances{});
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
index 9c3f0a4b964b22be3b1351967e24f485374b8639..ca0f9c81b162bf8a2541fe637764469bc34110a3 100644
--- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
@@ -1,13 +1,18 @@
-#include
-#include "config.hpp"
-#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
-#include "element_wise_operation.hpp"
-#include "device_operation_instance.hpp"
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_conv2d_fwd_bias_activation_instance {
+namespace instance {
using F16 = ck::half_t;
using F32 = float;
@@ -143,7 +148,7 @@ void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
        instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances{});
}
-} // namespace device_conv2d_fwd_bias_activation_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
index b9f46e26119b7116f076d1ffaa90afdb8b008806..91aa9182878c9ad83b05e9bda86e87fe70c0ede8 100644
--- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
@@ -1,13 +1,18 @@
-#include
-#include "config.hpp"
-#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp"
-#include "element_wise_operation.hpp"
-#include "device_operation_instance.hpp"
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_fwd_bias_activation_add_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -143,7 +148,7 @@ void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instan device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances{}); } -} // namespace device_conv2d_fwd_bias_activation_add_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt deleted file mode 100644 index 5906c7c5ac7ad226c8eed054fbd30d446fc51164..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# device_conv2d_fwd_bias_relu_atomic_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE - device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -add_library(device_conv2d_fwd_bias_relu_atomic_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) -set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) - -clang_tidy_check(device_conv2d_fwd_bias_relu_atomic_add_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp deleted file mode 100644 index c56ad270aa4cec377fe1d17109fffe90e8ebcc1f..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp +++ /dev/null @@ -1,69 +0,0 @@ -#include -#include "config.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv2d_fwd_bias_activation_atomic_add_instance { - -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; - -static constexpr auto InMemoryAtomicAdd = ck::InMemoryDataOperationEnum::AtomicAdd; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -using device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< - // clang-format off - //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| 
K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, - 
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2> - // clang-format on - >; - -void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( - std::vector>& - instance_container) -{ - using Instances = - device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances; - - const auto instances = Instances{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Instance = remove_cvref_t(instances))>; - - auto instance = Instance{}; - - instance_container.push_back(std::make_unique(instance)); - }); -} - -} // namespace device_conv2d_fwd_bias_activation_atomic_add_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index 745d26904aa8f86ed38e49c8728f68e81607597e..e55a3d2b5b3b31e4520abcc5d5d1a8aeb9f07ae2 100644 --- 
a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv3d_fwd_instance { +namespace instance { using F32 = float; using BF16 = bhalf_t; @@ -28,15 +33,12 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple< -// clang-format off + // clang-format off //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if !CK_WORKAROUND_GITHUB_135 - // FIXME: this instance causes numerical errors. 
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, -#endif DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, @@ -107,7 +109,7 @@ void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); } -} // namespace device_conv3d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp index 4d51180e72567343ca465059f43252b42514b10f..01c6cc6b3789a556b7bbd4f15a132e905cd29a7a 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_conv3d_fwd_instance {
+namespace instance {
using F16 = ck::half_t;
using F32 = float;
@@ -104,7 +109,7 @@ void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
        instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{});
}
-} // namespace device_conv3d_fwd_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp
index 9a8ff8d714335f287210ad07be114de0db212b92..f881958c91a1d3ff97f213392031ccc6a0ea06b8 100644
--- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp
@@ -1,13 +1,18 @@
-#include
-#include "config.hpp"
-#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
-#include "element_wise_operation.hpp"
-#include "device_operation_instance.hpp"
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_conv3d_fwd_instance {
+namespace instance {
using F32 = float;
@@ -103,7 +108,7 @@ void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
        instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{});
}
-} // namespace device_conv3d_fwd_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp
index 7f54b66f9b545fe395cd109a7792c5e8a9bbc3c2..d7c0a308746978d614ebd5cb84b4e07475e5caf0 100644
--- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp
@@ -1,13 +1,18 @@
-#include
-#include "config.hpp"
-#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
-#include "element_wise_operation.hpp"
-#include "device_operation_instance.hpp"
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv3d_fwd_instance { +namespace instance { using F32 = float; @@ -106,7 +111,7 @@ void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances( instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{}); } -} // namespace device_conv3d_fwd_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt index 037f8608086719ecacd8da7c29bff6ea51669162..dae633b7da84c7afe4f00eae1e01b4feffee4f86 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,5 @@ # device_convnd_bwd_data_instance -set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE +set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp; device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp; device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp; @@ -12,11 +12,11 @@ set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; -) +) add_library(device_convnd_bwd_data_instance OBJECT ${DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE}) target_compile_features(device_convnd_bwd_data_instance PUBLIC) set_target_properties(device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_convnd_bwd_data_instance LIBRARY DESTINATION lib) +rocm_install(TARGETS device_convnd_bwd_data_instance) clang_tidy_check(device_convnd_bwd_data_instance) diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp index 5c915dcc426142a9137f907cdb2f853851fb0c0c..a449a9053f33110516e0cbe6bffd37e6b54e2a49 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using BF16 = bhalf_t; using F32 = float; @@ -78,7 +83,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp index e8f7d4f11ad1512c4de8de1bf6c17b122b4f71a8..fb976740325937b47e4b6850258efaf9becce09b 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -80,7 +85,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp index b4c65ab66ab85492a9cdf5dfcbced65e2f03f620..e8f2a45b717479ab934b2f579f07d7f11cfa2548 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using F32 = float; @@ -77,7 +82,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp index e3958ef6891ce81f5244c5bea30826d92aae8d37..6aad1f029f5271d328e51910849a12c94e75fa7f 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using DataType = int8_t; using AccType = int32_t; @@ -80,7 +85,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index 2e4cd5cf312021bc0aa377d9dce7762cc8a6b4b1..010291cb47b8a3ccefe1188d13f2d0c5e93a43e8 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using BF16 = bhalf_t; using F32 = float; @@ -78,7 +83,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index 7170decc439d885410b748a5323488a24709599e..e7e147177a244f78e633c1f09fc50751a32a92fc 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -33,13 +38,11 @@ using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances = DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, -#if 1 DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 
true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, -#endif DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, @@ -80,7 +83,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index 5a727b1113a1a5cb6bf4bcc24fe6bf410fc7cf03..357ddabd108e327a62186f214155f8158369e119 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using F32 = float; @@ -77,7 +82,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index 3c53644ddc5bb88cb5c66376416181ca9b375648..3eadb0bdc9214b4da720b45b1b44033a23f22629 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using DataType = int8_t; using AccType = int32_t; @@ -32,10 +37,8 @@ using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances = //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, - #if 1 DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, 
DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, - #endif DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, @@ -58,9 +61,7 @@ using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, - #if 1 DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, - #endif DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, 
PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, @@ -82,7 +83,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index edbb7a14d9ece58855db6699494a259037d05104..6b5f71ff78ed42ec0116fde2181026e04db123d9 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using BF16 = bhalf_t; using F32 = float; @@ -78,7 +83,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp index 5d00fa8f08132422d9f06c8b84a9aea0de28537d..214aea289bd8bf8a0217784943d09664b32c8c5b 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -32,7 +37,6 @@ using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances = //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, -#if 1 DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, @@ -40,7 +44,6 @@ using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances = DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, -#endif DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, @@ -80,7 +83,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp index d5cd04de6b9060173c1ddf0d45dd2c1b6721c605..c3e8b5e8c7aa8c134ea030d5c6bc04b446c909c7 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using F32 = float; @@ -77,7 +82,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp index d5519706061545379e57b017f595e002411fff52..9142b8049b3a377cddd02969309c0331060126ab 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using DataType = int8_t; using AccType = int32_t; @@ -33,13 +38,11 @@ using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances = DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, -#if 1 DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, -#endif DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, @@ -80,7 +83,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{}); } -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7272163f2baee7a69f383bee8c65e54b7cd5b3b7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/CMakeLists.txt @@ -0,0 +1,19 @@ +#device_convnd_bwd_weight_instance +set(DEVICE_CONVND_BWD_WEIGHT_INSTANCE_SOURCE + device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp; + device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp; + device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp; + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + 
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; + device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; + device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; +) + +add_library(device_convnd_bwd_weight_instance OBJECT ${DEVICE_CONVND_BWD_WEIGHT_INSTANCE_SOURCE}) +target_compile_features(device_convnd_bwd_weight_instance PUBLIC) +set_target_properties(device_convnd_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +rocm_install(TARGETS device_convnd_bwd_weight_instance) + +clang_tidy_check(device_convnd_bwd_weight_instance) diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8aae435fee7353776ed16bf1dbae5dac1227a8f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< 
BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 
3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, 
PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e4964ce01124b675ccec61fb0c8dd68e4c0c784 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 
32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, 
PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed25442dc4121953d14f508db72d0a3bf36402f1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 
4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // 
clang-format on + >; + +using device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, 
PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..3a0dfeb6f4de7ebd8dd07d018107ffe109bac660 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, 
S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, 
S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_bf16_instances{}); + 
add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..025c7c86d8cc1ede4bbe93706a530e9fdca1c39a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_default_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, 
PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 
ConvBwdWeightDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 
1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 
3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_default_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cde50d779b9939cc53476fa84b6df5b15470c684 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_f32_default_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 
ConvBwdWeightDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 
4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 
3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_f32_default_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e2ad43a31520f1581cbb1c997e826a00b8d9518 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, 
BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, 
true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, 
ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..647a8982422ea1c34f3f6c46ac576a2b40a48406 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
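(Editorial aside, not part of the patch.) Each of these new *_instance.cpp files follows the same registration pattern: a std::tuple lists fully specialized device-op types, and add_device_operation_instances appends one heap-allocated object per tuple element to a vector of base-class pointers that callers can then query and launch. The snippet below is a minimal, self-contained sketch of that mechanism under stated assumptions; BaseOperator, DummyDeviceOp, and add_device_operation_instances_sketch are illustrative stand-ins, not CK's real classes or the helper's exact signature.

// Sketch of the tuple-to-vector instance registration used by the new instance files.
// All names here are hypothetical stand-ins for the CK types in this diff.
#include <memory>
#include <tuple>
#include <vector>

struct BaseOperator // stand-in for the device-op base class behind the elided pointer alias
{
    virtual ~BaseOperator() = default;
};

template <int BlockSize, int MPerBlock, int NPerBlock>
struct DummyDeviceOp : BaseOperator // stand-in for DeviceConvndBwdWeightXdl_C_Shuffle_... specializations
{
};

template <typename OpTuple>
void add_device_operation_instances_sketch(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    // Append one object per tuple element; the real helper iterates the tuple the same way.
    std::apply(
        [&](auto... ops) {
            (instances.push_back(std::make_unique<decltype(ops)>(std::move(ops))), ...);
        },
        OpTuple{});
}

// Mirrors the shape of e.g. device_conv3d_bwd_weight_xdl_..._bf16_instances above,
// with made-up tuning parameters.
using dummy_instances = std::tuple<DummyDeviceOp<256, 256, 128>, DummyDeviceOp<128, 128, 64>>;

void add_dummy_instances(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances_sketch<dummy_instances>(instances);
}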
+ +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_f16_default_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 
2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 
1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, 
F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_f16_default_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp 
b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..40754a09f03cf456b37727c857443ced59b651c3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_f32_default_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 
ConvBwdWeightDefault, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 
32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 
2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + 
device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_f32_default_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/device_conv2d.cpp b/library/src/tensor_operation_instance/gpu/device_conv2d.cpp deleted file mode 100644 index 6b99433ffa211cb4d385697fd994a6ffae743066..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/device_conv2d.cpp +++ /dev/null @@ -1,201 +0,0 @@ -#include -#include "config.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" -#include "host_interface.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv2d_fwd_instance { -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( - std::vector>& instances); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( - std::vector>& instances); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( - std::vector>& instances); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( - std::vector>& instances); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( - std::vector>& instances); - -} // namespace device_conv2d_fwd_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl -{ - std::unique_ptr - MakeArgumentPointer(void* in_ptr, - void* wei_ptr, - void* out_ptr, - size_t N, - size_t K, - size_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) const - { - return el->MakeArgumentPointer(in_ptr, - wei_ptr, - out_ptr, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - } - std::unique_ptr MakeInvokerPointer() const - { - return el->MakeInvokerPointer(); - } - - std::string GetTypeString() { return el->GetTypeString(); } - bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg) - { - return el->IsSupportedArgument(arg); - } - - ck::tensor_operation::device::DeviceConvFwdPtr el; -}; - -DeviceConvFwdPtr_t::DeviceConvFwdPtr_t() : pImpl(nullptr) {} -DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default; -DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default; -DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other) - : pImpl(std::make_unique(std::move(other))) -{ -} - -std::unique_ptr -DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr, - void* wei_ptr, - void* out_ptr, - size_t N, - size_t K, - size_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) const -{ - return pImpl->MakeArgumentPointer(in_ptr, - wei_ptr, - 
out_ptr, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); -} - -std::unique_ptr DeviceConvFwdPtr_t::MakeInvokerPointer() const -{ - return pImpl->MakeInvokerPointer(); -} - -std::string DeviceConvFwdPtr_t::GetTypeString() { return pImpl->GetTypeString(); } -bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr) -{ - return pImpl->IsSupportedArgument(arg_ptr); -} - -using namespace ck::tensor_operation::device::device_conv2d_fwd_instance; -void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t( - std::vector& instances) -{ - std::vector< - ck::tensor_operation::device::DeviceConvFwdPtr> - local_instances; - add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances); - for(auto& kinder : local_instances) - { - DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; - instances.emplace_back(tmp); - } - return; -} - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t( - std::vector& instances) -{ - std::vector< - ck::tensor_operation::device::DeviceConvFwdPtr> - local_instances; - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances); - for(auto& kinder : local_instances) - { - DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; - instances.emplace_back(tmp); // Perhaps we can do better - } - return; -} - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t( - std::vector& instances) -{ - std::vector< - ck::tensor_operation::device::DeviceConvFwdPtr> - local_instances; - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances); - for(auto& kinder : local_instances) - { - DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; - instances.emplace_back(tmp); // Perhaps we can do better - } - return; -} - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t( - std::vector& instances) -{ - std::vector< - ck::tensor_operation::device::DeviceConvFwdPtr> - local_instances; - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances); - for(auto& kinder : local_instances) - { - DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; - instances.emplace_back(tmp); // Perhaps we can do better - } - return; -} - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t( - std::vector& instances) -{ - std::vector< - ck::tensor_operation::device::DeviceConvFwdPtr> - local_instances; - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances); - for(auto& kinder : local_instances) - { - DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; - instances.emplace_back(tmp); - } - return; -} diff --git a/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..465ba4e9843f0b3d8780cbfb4ce89edfffec5d8d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt @@ -0,0 +1,10 @@ +set(DEVICE_ELEMENTWISE_INSTANCE_SOURCE + device_normalize_instance.cpp +) + +add_instance_library(device_elementwise_instance ${DEVICE_ELEMENTWISE_INSTANCE_SOURCE}) + +target_compile_features(device_elementwise_instance PUBLIC) +set_target_properties(device_elementwise_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_elementwise_instance) diff --git 
a/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp b/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12f7901c165fc45a5fbdaf9f4deffc4e6252599e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using inputType = F16; +using MeanType = F32; +using SquareMeanType = F32; +using GammaDataType = F16; +using BetaDataType = F16; +using outputType = F16; + +using Normalize = ck::tensor_operation::element_wise::Normalize; +using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std::tuple< + // clang-format off + //###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector| + //###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector| + //###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector| + //###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector| + Device5AryElementwise, + Device5AryElementwise, + Device5AryElementwise, + Device5AryElementwise + // clang-format on + >; + +void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 8de1920bb3db9bacb309f72040c2b2ae54338bd3..ce66b56a3e398a530804bbc554145e60e939d2f4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -28,14 +28,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; - device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; - device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; - device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; - device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp; - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; - device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; - device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; 
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp index db7f6af04b4c493351560e11a4e34f591c2ee55a..41efbdcc207801d1c3e0979fd7136d898b4f1963 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< // clang-format off // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 
2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -34,12 +39,14 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< >; void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp index c4253bcc4cd2f79329e64604495a33dbfe48a71e..0e6d6239af9936e9fd2c2260979d380080f1b91d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< // clang-format off // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| 
Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -34,12 +39,14 @@ using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< >; void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp index d19d11f1f8a567d029233eafad04dd78a5d2ff09..bc2186e584674638c14948eec5110a722a1c6e94 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
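(Editorial aside, not part of the patch.) The GEMM DL instance files touched below are named by their A/B/C layout combination, and the DeviceGemmDl template arguments in this diff confirm the mapping: mk_kn uses Row/Row, km_kn uses Col/Row, mk_nk uses Row/Col, km_nk uses Col/Col, all with a row-major C. The short sketch below just records that mapping with stand-in tag types rather than CK's ck::tensor_layout classes, which are assumed here.

// Reading aid for the file-name suffixes of the GEMM instance files in this diff.
// RowMajor/ColMajor/GemmLayoutCombo are illustrative stand-ins, not CK types.
struct RowMajor {};
struct ColMajor {};

template <typename ALayout, typename BLayout, typename CLayout>
struct GemmLayoutCombo {};

// mk_kn_mn : A row-major (M x K), B row-major (K x N), C row-major (M x N)
using gemm_mk_kn_mn = GemmLayoutCombo<RowMajor, RowMajor, RowMajor>;
// km_kn_mn : A column-major (stored K x M), B row-major, C row-major
using gemm_km_kn_mn = GemmLayoutCombo<ColMajor, RowMajor, RowMajor>;
// mk_nk_mn : A row-major, B column-major (stored N x K), C row-major
using gemm_mk_nk_mn = GemmLayoutCombo<RowMajor, ColMajor, RowMajor>;
// km_nk_mn : A column-major, B column-major, C row-major
using gemm_km_nk_mn = GemmLayoutCombo<ColMajor, ColMajor, RowMajor>;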
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -34,12 +39,14 @@ using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp index 
cd86e5ceaedca27a2df506ad5730cf6bd0a5da65..e2000afb53884662f8ef60b55277724e90a38074 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -35,12 +40,14 @@ using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = >; void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { 
add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp index 3fcc5fdfdcb997a8c6ac2facb5fa9177231c5f4b..267e3d76b9e227705598bf16995b148f8246197e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -34,12 +39,14 @@ using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = std::tuple< >; void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp index 8cd32128b556693ae0ea8106b0f9ec4f8853fc91..f8bb758b3d012130848e740a8a6f86d22a555173 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | 
| | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -35,12 +40,14 @@ using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = >; void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp index 4c4bfc440d646a4f4066b846745b7173f5755ceb..54bb6810ff2de0ced8eb5d6bca2693069e711809 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
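Each instance file above funnels its std::tuple of concrete DeviceGemmDl types through add_device_operation_instances(), which every file now pulls in from ck/library/tensor_operation_instance/add_device_operation_instance.hpp. The snippet below is only an illustration of what such a helper has to do, not the library's implementation; the names are hypothetical:

    #include <memory>
    #include <tuple>
    #include <vector>

    // Expand the tuple of instance types and append one default-constructed,
    // heap-allocated object per type to the caller's vector of base-class
    // pointers (illustrative re-implementation, C++17 fold expression).
    template <typename BaseOpPtrVector, typename... DeviceOps>
    void add_device_operation_instances_sketch(BaseOpPtrVector& instances,
                                               std::tuple<DeviceOps...>)
    {
        (instances.push_back(std::make_unique<DeviceOps>()), ...);
    }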
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -35,12 +40,14 @@ using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = >; void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp index c6077341b1ce423087634a0c106c9e1b4fa2b605..1ce46ec7ecc30d2ba6e0e06baa77c98775aad6d9 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -35,12 +40,14 @@ using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = >; void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_nk_mn_instances{}); } -} // namespace 
device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp index 91b68d4bf23e868cca3825bac603d23d900806c9..f18adfee682c3c8214f30b225761726ec6616d54 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -23,7 +28,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< // clang-format off // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, 
S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -31,12 +36,14 @@ using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< >; void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp index 13b185fd936e326183734829bc4905f7cb0c2375..91277b546a6bf44aaf236dac1d8b437058a15996 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -23,7 +28,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< // clang-format off // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // 
#########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -31,12 +36,14 @@ using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< >; void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp index ff4a89beb4d6bf1ba73c81f03590643dc6ecc217..a56d9d2c2f09215cf9c83c71450290ec6b30a41f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
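On the consumer side, these entry points are reached through the renamed ck::tensor_operation::device::instance namespace. The element type of the instances vector is elided in the hunks above, so the sketch below assumes the layout- and data-type-templated DeviceGemm base class; treat the container signature and header paths as assumptions, with only the namespace and the add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances() name taken from this diff:

    #include <cstdint>
    #include <memory>
    #include <vector>

    #include "ck/tensor_operation/gpu/device/device_gemm.hpp"               // assumed
    #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
    #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Assumed element type: A is k-by-m column-major, B k-by-n row-major,
    // C m-by-n row-major, int8 in, int8 out.
    using DeviceGemmI8Ptr = std::unique_ptr<ck::tensor_operation::device::DeviceGemm<
        Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>;

    namespace ck::tensor_operation::device::instance {
    // Normally declared by the library's instance headers; restated here under
    // the assumed signature so the sketch stands alone.
    void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmI8Ptr>&);
    } // namespace ck::tensor_operation::device::instance

    int main()
    {
        std::vector<DeviceGemmI8Ptr> gemm_ptrs;
        ck::tensor_operation::device::instance::
            add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs);

        // Each pointer can now build its Argument/Invoker pair, be checked for
        // support on the current device, and be timed by a profiling loop.
        return gemm_ptrs.empty() ? 1 : 0;
    }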
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -23,7 +28,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< // clang-format off // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -31,12 +36,14 @@ using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp index 
e32158a292d4abbd37e18d9ec9a328ea99e51bfb..63794ac39c0ab20a30735590fd1f1c68b1fc2eab 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -23,7 +28,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = std::tuple< // clang-format off // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> @@ -31,12 +36,14 @@ using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = 
std::tuple< >; void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp index de97b60a62a7daccd267a5a65848075c5d35c2d9..16037f704ca17638a2a51809b29e811610715cb7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, @@ -46,13 +51,15 @@ using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tu >; void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances( instances, device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp index 5e99c67b3f7f8dd3486f34e2088357bf14964d05..9ce9dc480a0c702696c7eeeec94d05d5060ac46c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, @@ -49,13 +54,15 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp index 321b97fd30e8a9b332a9b4474984ad80d1ac8114..83b01e2656c5a47f708b3e1360dec313ebf4d89c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
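All of the instance lists shown here pin the kernel to the default GEMM specialization through the GemmDefault constant, whose defining line is truncated in the hunks above. A short sketch of that binding, using the gemm_specialization.hpp header that every file in this diff now includes; the enumerator names are assumptions of this note:

    #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"

    // One value of CK's GemmSpecialization enum is baked into each instance list.
    static constexpr auto GemmDefault =
        ck::tensor_operation::device::GemmSpecialization::Default;

    // Padded variants (e.g. GemmSpecialization::MNKPadding, assumed enumerator)
    // exist for problem sizes that are not multiples of the tile dimensions;
    // the lists shown in this diff use only the unpadded default.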
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, @@ -49,13 +54,15 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp index 1d69a23dd72ccb4336126dc02df488e014b94152..2a36451192af20b24264928f79661c2b2d4f90b4 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, @@ -49,13 +54,15 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances{}); } -} // 
namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp index 8ffa2b8b8676a6e55223863e8787b9568b75e1df..938c99cb33b8e66b3ac78c233c7a92eeaf505cab 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using BF16 = ck::bhalf_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, @@ -46,13 +51,15 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp index 09adf1678d23389d419825a38ad4ec11ecee3213..7066be07f08bb470be6fae5e8679d75d6c180c3d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,13 +1,20 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +33,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| 
Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, @@ -49,13 +56,15 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp index 121b5857b2e4bf76c14186a235ca3a244fb0a0e2..39b2e73c2b42894340e1ecf6fac38bc685fcf37b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,13 +1,20 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +33,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, @@ -49,13 +56,15 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index 2073d5f50ec4e12a26fbb828ff4bf3152ca609e2..b4b8cc3389105b8189767e8a44163ab33dd9e6d7 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,13 +1,20 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +33,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, @@ -49,13 +56,15 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, 
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index e177ee60ec9ea0b6d99753a2eaac93d1426dc057..8f0996c351d430c843c0f51e7b78b53ec6a84e26 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,13 +1,20 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +33,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, 
F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, @@ -46,13 +53,15 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp index ff830d416197881c9256022095d5c5aec0240a6c..5c7e7d3514e1266f9df3e9f96a7f2493c49d5c54 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F32 = float; @@ -25,7 +30,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | 
| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, @@ -48,13 +53,15 @@ using device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp index 79bca77aad19a49e0c2e61fe9ea0886c780f4be6..45ae6c51abec4458cf6a5124ff4269decba71bcb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
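
Editor's note: the file-name suffixes encode the operand layouts that appear as the first three template arguments of each tuple entry in this diff — mk_kn_mn maps to Row/Row/Row, mk_nk_mn to Row/Col/Row, km_kn_mn to Col/Row/Row, and km_nk_mn to Col/Col/Row. The helpers below only illustrate what the Row and Col tags mean for addressing; they are not CK APIs and the leading-dimension handling is deliberately simplified.

    #include <cstddef>

    // Row-major: element (i, j) of a matrix with leading dimension ld sits at i * ld + j.
    inline std::size_t offset_row_major(std::size_t i, std::size_t j, std::size_t ld)
    {
        return i * ld + j;
    }

    // Column-major: the same element sits at j * ld + i.
    inline std::size_t offset_col_major(std::size_t i, std::size_t j, std::size_t ld)
    {
        return j * ld + i;
    }
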
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F32 = float; @@ -25,7 +30,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, @@ -48,13 +53,15 @@ using device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp index fac4e8d96ee843ed15c75f3c1245120dfca14a9b..455d786f041024f1cd166d870eaafb264c59f568 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F32 = float; @@ -25,7 +30,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, @@ -48,13 +53,15 @@ using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace 
instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp index ffcd957913e2411836a945ddafa1238f7f6287f6..5667bce36404683862a4534c071bfce96c61115e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F32 = float; @@ -25,7 +30,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 
1, 1, S<1, 16, 1, 16>, 4>, @@ -45,13 +50,15 @@ using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp index 2185b55aac09f386322028a1154e68264414d953..ee88c9a0b2b1737cd55c258fb2a07e13ea157715 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F32 = float; @@ -26,7 +31,7 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, @@ -49,13 +54,15 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp index 90966349b217777cc28cb82891b0f114d3838d43..35405578532618bc1e4a0bf4f9e188fc92aa3d0a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F32 = float; @@ -26,7 +31,7 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | 
| | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, @@ -49,13 +54,15 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp index aa5a13001c02af118b4f1d3a6b59e9ca73dcff73..a1090695498640af5f395035e9c3e7bc2e017689 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
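
Editor's note: the int8 instances in this group pair int8 A/B/C data with a 32-bit accumulator (the int32_t arguments visible in the tuples above). A single 8-bit product can already reach 127 * 127 = 16129, and a GEMM adds K such products per output element, so an 8-bit or 16-bit accumulator would quickly overflow. The scalar sketch below exists only to make that arithmetic concrete; dot_i8 is a hypothetical helper, not part of the library.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Accumulate int8 products in int32, mirroring the AccData choice in the instances above.
    std::int32_t dot_i8(const std::vector<std::int8_t>& a, const std::vector<std::int8_t>& b)
    {
        std::int32_t acc = 0;
        const std::size_t n = a.size() < b.size() ? a.size() : b.size();
        for(std::size_t k = 0; k < n; ++k)
            acc += static_cast<std::int32_t>(a[k]) * static_cast<std::int32_t>(b[k]);
        return acc;
    }
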
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F32 = float; @@ -26,7 +31,7 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, @@ -49,13 +54,15 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 82eec1164af964fdc83e6d2c107e406d8b337302..be8de8be5df2d5a3912b6be3c1ea955e2f2dba81 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F32 = float; @@ -26,7 +31,7 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, @@ -46,13 +51,15 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp index 08047c7e52b49b832f2a4ee1559d2234d3570f77..5fee5384711a10096fd4943779994846451f68fc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, @@ -42,12 +47,14 @@ using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp index 05cb080cbfde809904db545cbd1152c4cef5a087..4363bfe9271e46d95592fb07bcb2b0603c4cc744 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, @@ -42,12 +47,14 @@ using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index 4de989caf0ca861702bd08d9354d9d0722405f09..544eb02f3e942afd0e0d26659b7b74171150c9d6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, @@ -51,12 +56,14 @@ using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 633e2aac2e4ade43e8158cc7075349c522762329..8ce8eb08152eb807dc9b83f113ea24667a4dd316 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -28,7 +33,7 @@ using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, @@ -52,7 +57,7 @@ using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< // clang-format off //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| 
Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, @@ -61,14 +66,16 @@ using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = >; void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp index 8284311102dcb475504cc2d80e49670d3417fa43..b99c023d61221cbe319d6523c9fc7f3b61771e16 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, @@ -42,12 +47,14 @@ using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = >; void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp index 235c4771f9e8a180b030ac64c3a908af3dace600..99a2383c706d9ad93d0e9cbe2258b1f6eb7b6909 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, @@ -42,12 +47,14 @@ using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances = >; void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp index b7000bddf871426e4ede0ea92742174b824d4a85..8794275d34f3631e9085cfabe2f5efe6df212c83 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, @@ -42,12 +47,14 @@ using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances = >; void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp index 1b4f23141b3befa160c0762c95daefb33b5fc9e2..4b62cec6089c36ce2fa0baf9a255a0d20d9e432d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, @@ -47,12 +52,14 @@ using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances = >; void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp index fdc85dfc7109acb7c68167d95a3ef670f40eaf7c..a02763bca3099005aa79db8acfdb0893624e2836 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F64 = double; @@ -26,7 +31,7 @@ using device_gemm_xdl_f64_f64_f64_km_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, @@ -38,12 +43,14 @@ using device_gemm_xdl_f64_f64_f64_km_kn_mn_instances = >; void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp index e400cd9bbba98f649d0b3ff9f438a43f94df5e67..1275197feab36d5dd3df4fe39923a2977af1521d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F64 = double; @@ -26,7 +31,7 @@ using device_gemm_xdl_f64_f64_f64_km_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, @@ -38,12 +43,14 @@ using device_gemm_xdl_f64_f64_f64_km_nk_mn_instances = >; void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp index 2f9241b93b3ebb4b6eff775e76dd1913bb33e2d4..d763c68f9e53c782678ffca0677a55346ef91297 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F64 = double; @@ -26,7 +31,7 @@ using device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, @@ -38,12 +43,14 @@ using device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances = >; void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp index 537fe2bdae7a50bac58b830eaa60d8be74e11e2c..e52e3ff61b3d1560fa675ca900bd944e0cc54b2f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
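
[Editor's note, not part of the patch] The hunks above move every GEMM instance factory into `ck::tensor_operation::device::instance` and change its parameter to a vector of layout- and type-parameterized `DeviceGemm` handles (the exact template arguments are stripped in this rendering of the diff). A minimal client sketch follows, assuming the post-refactor `DeviceGemm` base template is parameterized as (ALayout, BLayout, CLayout, AData, BData, CData, AElementwiseOp, BElementwiseOp, CElementwiseOp), that the header paths match those used in the hunks, and that instances expose `GetTypeString()` from the library's base-operator interface; treat all of these as assumptions rather than confirmed API.

```cpp
// Editor's sketch only: enumerate the refactored f32 row/row/row GEMM instances.
// DeviceGemm template-parameter order, header paths, and the factory signature
// are assumptions inferred from the hunks above, not confirmed API.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Assumed handle type: the mk_kn_mn factory name implies Row/Row/Row layouts.
using DeviceGemmF32RRR = ck::tensor_operation::device::
    DeviceGemm<Row, Row, Row, float, float, float, PassThrough, PassThrough, PassThrough>;

namespace ck::tensor_operation::device::instance {
// Forward declaration matching (assumed) the definition in
// device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp above; link against the
// device_gemm_instance object library to resolve it.
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmF32RRR>>& instances);
} // namespace ck::tensor_operation::device::instance

int main()
{
    std::vector<std::unique_ptr<DeviceGemmF32RRR>> gemm_instances;

    ck::tensor_operation::device::instance::add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
        gemm_instances);

    // Each instance reports its tuning configuration (tile sizes, vector widths, ...).
    for(const auto& op : gemm_instances)
        std::cout << op->GetTypeString() << '\n';
}
```
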
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F64 = double; @@ -26,7 +31,7 @@ using device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, @@ -43,12 +48,14 @@ using device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances = >; void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..789c5b628f1bee3e6df79d93a1598ad11807e221 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1,14 @@ +# device_gemm_add_add_fastgelu_instance +set(DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; +) + +add_library(device_gemm_add_add_fastgelu_instance OBJECT 
${DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE}) + +target_compile_features(device_gemm_add_add_fastgelu_instance PUBLIC) +set_target_properties(device_gemm_add_add_fastgelu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_gemm_add_add_fastgelu_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e00a66c5dfe7fd690735ac9d2d837408660041af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< 
Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 
128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5f398937a0a2c3d2242bb800c36ab04e596ef60 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[k, m], b[n, k], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, 
F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8e2b5cf6699ee1696724b925a35280b14c3f4ead --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| 
NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 
2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e28889a29d8acf78ba54bcfcc0e433beb4dabd8f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
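// --- Illustrative usage sketch (not part of this patch) ---------------------
// The gemm_add_add_fastgelu instance files added in this change pre-instantiate
// DeviceGemmMultipleD_Xdl_CShuffle configurations that compute, per the source
// comments, e[m, n] = FastGelu((a * b)[m, n] + d0[m, n] + d1[m, n]) and expose
// them through the add_device_gemm_add_add_fastgelu_xdl_c_shuffle_*_instances()
// factory functions. The sketch below shows how a client would typically gather
// the instances and run a supported one; the DeviceGemmMultipleD template
// arguments are an assumption (they are elided in the hunks of this section),
// and the call pattern follows the usual CK device-op interface
// (MakeArgumentPointer / IsSupportedArgument / MakeInvokerPointer).
//
// #include <memory>
// #include <vector>
//
// using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
//     Row, Col, ck::Tuple<Row, Row>, Row,        // A, B, (D0, D1), E layouts
//     F16, F16, ck::Tuple<F16, F16>, F16,        // A, B, (D0, D1), E data types
//     PassThrough, PassThrough, AddAddFastGelu>; // A, B, CDE elementwise ops
//
// std::vector<std::unique_ptr<DeviceOp>> ops;
// add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(ops);
// for(const auto& op : ops)
// {
//     auto arg = op->MakeArgumentPointer(/* pointers, lengths, strides, element ops */);
//     if(op->IsSupportedArgument(arg.get()))
//         op->MakeInvokerPointer()->Run(arg.get());
// }
// -----------------------------------------------------------------------------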
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[n, k], d0[m, n], d1[m ,n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, 
F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt 
b/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt deleted file mode 100644 index e2b0abb1d105930b15ab57ffcb17c060981cc87c..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -# device_gemm_bias2d_instance -set(DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE - device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp; -) - -add_library(device_gemm_bias2d_instance OBJECT ${DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE}) -set_target_properties(device_gemm_bias2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) - -clang_tidy_check(device_gemm_bias2d_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp deleted file mode 100644 index bd16850ee4f1069331b3a86ac72ede2ec92be71d..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances = std::tuple< - // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| 
PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp deleted file mode 100644 index 12740ce256f689cf54cc35d7ba2508f1c927f1be..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace 
device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances = std::tuple< - // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, 
PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 56db0475efe8183127bdb2c8d71afd334239554a..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp deleted file mode 100644 index b20ee8db69a8550adeca206d52c324bf28190448..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AlphaBetaAdd = 
ck::tensor_operation::element_wise::AlphaBetaAdd; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 
256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp deleted file mode 100644 index 11984c36db58ba39115b4f32e46dbccd0b35518a..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances = std::tuple< - // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp deleted file mode 100644 index bd0a9880594f7e725dcdee49a6129d479800be37..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances = std::tuple< - // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 
4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp deleted file mode 100644 index 440ea1582e57cac71977e28e66b84692141ea298..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances = std::tuple< - // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp deleted file mode 100644 index fab885969f79c4dbf2b472f42c303129a4a41ee7..0000000000000000000000000000000000000000 
--- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances = std::tuple< - // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 
1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, - DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..85a7f3f06183e0f7cbf1f386540db19492095cb3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -0,0 +1,13 @@ +set(DEVICE_GEMM_BIAS_ADD_REDUCE_INSTANCE_SOURCE + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +) + +add_library(device_gemm_bias_add_reduce_instance OBJECT ${DEVICE_GEMM_BIAS_ADD_REDUCE_INSTANCE_SOURCE}) + 
+target_compile_features(device_gemm_bias_add_reduce_instance PUBLIC) +set_target_properties(device_gemm_bias_add_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_gemm_bias_add_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aec29f2aa17be2f5e7615246344b59f4bae32c42 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, 
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 
1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ab8e707884f9776d63845dabfb31705bb7ed1e1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, 
Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, 
PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..31377ef8286bfdc3ef1ba5c36aac2bd68b1b89e6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, 
S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d313fc367d5e0dfb6425f56d760030d2b011ca02 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | 
Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, 
PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt deleted file mode 100644 index e2e7d4badd2278ff519b1b59abd5a6c04183f0a5..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -# device_gemm_bias_relu_instance -set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE - device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp; - 
device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; -) - -add_library(device_gemm_bias_relu_instance OBJECT ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) -set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) - -clang_tidy_check(device_gemm_bias_relu_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp deleted file mode 100644 index 4927a05ca4eb81d789083f1c4881364357fa518c..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; - -// c[m, n] = ReLU(a[k, m] * b[k, n] + c0[n]) -using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances = std::tuple< - // clang-format off - //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - 
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp deleted file mode 100644 index f712f9de11802911400723126029674b15535157..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; - -// c[m, n] = ReLU(a[k, m] * b[n, k] + c0[n]) -using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances = std::tuple< - // clang-format off - //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( 
- instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 26af05bbde4c3368de3206cf8e194dc04b58de2e..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; - -// c[m, n] = ReLU(a[m, k] * b[k, n] + c0[n]) -using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp deleted file mode 100644 index 901b7a5d644c3623309ba55899bf7e93bf5c9825..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; - -// c[m, n] = ReLU(a[m, k] * b[n, k] + c0[n]) -using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| 
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, 
PassThrough, PassThrough, AddRelu, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt deleted file mode 100644 index a10dbb555dc7b3c928451a7784dcc11bc6c4d661..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -# device_gemm_bias_relu_add_instance -set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE - device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; -) - -add_library(device_gemm_bias_relu_add_instance OBJECT ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) -set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) - -clang_tidy_check(device_gemm_bias_relu_add_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp deleted file mode 100644 index c26f66a9ed5fe3328ac0de30ed3e5897bcc82fcc..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = 
ck::tensor_operation::element_wise::PassThrough; -using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; - -// c[m, n] = ReLU(a[k, m] * b[k, n] + c0[n]) + c1[m, n] -using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances = std::tuple< - // clang-format off - //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp deleted file mode 100644 index c0950666b171c782c1264f1ab586892f5b48838d..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; - -// c[m, n] = ReLU(a[k, m] * b[n, k] + c0[n]) + c1[m, n] -using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances = std::tuple< - // clang-format off - //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 
32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 42c1f72d6e633fbffbc2bc6c3b84ec39376b8561..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; 
- -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; - -// c[m, n] = ReLU(a[m, k] * b[k, n] + c0[n]) + c1[m, n] -using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 
2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp deleted file mode 100644 index 3961def81d3d906bdc58f19a7bee95d3026f295b..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; - -// c[m, n] = ReLU(a[m, k] * b[n, k] + c0[n]) + c1[m, n] -using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, 
PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 
16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> - // clang-format on - >; - -void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e6c93da88c83244b018c685be701fa4a717ac5f4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_gemm_bilinear_instance +set(DEVICE_GEMM_BILINEAR_INSTANCE_SOURCE + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; +) + +add_library(device_gemm_bilinear_instance OBJECT ${DEVICE_GEMM_BILINEAR_INSTANCE_SOURCE}) +set_target_properties(device_gemm_bilinear_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_gemm_bilinear_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b8777a4241c6d88115e9f09013f72da8a1092ed --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, 
F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 
8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K Padding + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, 
Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..589e4bf6d19191503aa0d4b394d35ffd56c7d446 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 
32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, 
PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K Padding + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 
128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void 
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d18b7c26681f271f1d10dc9dbff9c903a7a4ba47 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; +; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K padding + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 
32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29763ea4a2010c0f6181146d3f63ed4c723c80b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_TUPLE = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| 
_MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, 
F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // M/N/K padding + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear,
GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 5bc6d17a93ae513e57b5c5892f63f8b6cb42961b..5fbdc28d7b699eb5a9c9da366f156c8a496837d0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -6,5 +6,5 @@ set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE ) add_instance_library(device_gemm_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE}) -install(TARGETS device_gemm_reduce_instance 
LIBRARY DESTINATION lib) +rocm_install(TARGETS device_gemm_reduce_instance) clang_tidy_check(device_gemm_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp index 83ed803f5e177ced6e58dde18e1f958d10b6c494..f32303dbe0d09605458a1b61405c303513f4c444 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -1,18 +1,24 @@ -#include -#include "config.hpp" -#include "device_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { -using F16 = ck::half_t; -using F32 = float; -using DPtrsGlobal = ck::Tuple; +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,14 +27,14 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; +using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; -using Div = ck::tensor_operation::element_wise::UnaryIdentic; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -38,42 +44,37 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[k, m] * b[k, n] using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - 
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, 
ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, 
ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index cf73afde1d3b81bf233e188faebf55514ddf43f1..82acbccea65c04cf8cffd8bf9e19da11741559c4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -1,18 +1,24 @@ -#include -#include "config.hpp" -#include "device_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { -using F16 = ck::half_t; -using F32 = float; -using DPtrsGlobal = ck::Tuple; +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,14 +27,14 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; +using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; -using Div = ck::tensor_operation::element_wise::UnaryIdentic; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -38,42 +44,37 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[k, m] * b[n, k] using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< 
Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, 
PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, 
F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // 
clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index a8f7dccb4d9b86ae9b1891320c3adc03b96834a2..978a4cb353a9a26c5be1318fbdb19a92d4664cc5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -1,18 +1,24 @@ -#include -#include "config.hpp" -#include "device_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { -using F16 = ck::half_t; -using F32 = float; -using DPtrsGlobal = ck::Tuple; +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,14 +27,14 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; +using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; -using Div = ck::tensor_operation::element_wise::UnaryIdentic; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -38,42 +44,37 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[m, k] * b[n, k] using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 
8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 
8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index 63bc293aa43519f4915ad6de5517c2a2a449ffea..a067449f4cb5ab426f58ca5d8e3ec3029fff0040 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -1,18 +1,24 @@ -#include -#include "config.hpp" -#include "device_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
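// Editor's note (illustrative sketch, not part of the patch): in these gemm_reduce
// instance files the reduction configuration is
//   ReduceOps           = { ReduceSum, ReduceSum }
//   ReduceInElementOps  = { PassThrough, UnarySquare }
//   ReduceOutElementOps = { UnaryDivide, UnaryDivide }
// i.e. each GEMM is fused with two row-wise reductions that amount to a mean and a
// mean-of-squares of C. Host-side reference of that math, assuming the divide is by
// the reduction length N; the function and parameter names here are hypothetical:
#include <vector>
void reference_gemm_mean_meansq(const std::vector<float>& c, int M, int N,
                                std::vector<float>& d0, std::vector<float>& d1)
{
    d0.assign(M, 0.f);
    d1.assign(M, 0.f);
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            const float v = c[m * N + n];
            d0[m] += v;     // PassThrough, then ReduceSum
            d1[m] += v * v; // UnarySquare, then ReduceSum
        }
        d0[m] /= N; // UnaryDivide applied to the accumulated value
        d1[m] /= N;
    }
}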
+ +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { -using F16 = ck::half_t; -using F32 = float; -using DPtrsGlobal = ck::Tuple; +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,14 +27,14 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; +using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; -using Div = ck::tensor_operation::element_wise::UnaryIdentic; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -38,39 +44,34 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[m, k] * b[n, k] using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< 
Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, 
PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + 
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3700ddf19d4f5eda4f7d1f638a9828a0f4719d29 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -0,0 +1,15 @@ +set(DEVICE_GEMM_SPLITK_INSTANCE_SOURCE + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; +) + +add_library(device_gemm_splitk_instance OBJECT ${DEVICE_GEMM_SPLITK_INSTANCE_SOURCE}) + +target_compile_features(device_gemm_splitk_instance PUBLIC) +set_target_properties(device_gemm_splitk_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp similarity index 90% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp index 26ec965bb50c58a5947301b2341a22409759fe90..da59b91f0e672515a3c8e042479c7bf74bb174a4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp 
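// Editor's note (illustrative, not part of the patch): the split-K hunks below move
// these files into the new device_gemm_splitk_instance OBJECT library and rename the
// namespace device_gemm_instance to instance. Downstream code that spelled out the old
// namespace could keep compiling with a namespace alias like this hypothetical
// compatibility shim (not something this patch adds):
namespace ck { namespace tensor_operation { namespace device { namespace instance {} } } }
namespace ck { namespace tensor_operation { namespace device {
namespace device_gemm_instance = instance; // old spelling resolves to the new namespace
} } }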
@@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_splitk_c_shuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +32,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, @@ -41,13 +47,15 @@ using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp similarity index 90% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp index 45e3f9f9400d36bd97e2f3d1feea8e65d80acf23..aa65e1343333ea9f091fda0dff367cbcd2af809b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_splitk_c_shuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +32,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 
3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, @@ -41,13 +47,15 @@ using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp similarity index 90% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index 042ac2b8cae4b409ab8f76354e59de9edc31bfca..32b229c6cbef76dc8c2ec69f87fa77981cb37b6a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_splitk_c_shuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +32,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, @@ -41,13 +47,15 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp similarity index 74% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index 21fdb7cd9df7d0d0e3a1f1fb1234249b623854a6..004143afe5c3d5ade861f745fb5c63248ca3a109 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_splitk_c_shuffle.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
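// Editor's note (illustrative sketch, not part of the patch): the DeviceGemmXdlSplitKCShuffle
// instances in these files implement split-K GEMM: the K dimension is carved into
// k_batch slices, each slice computes a partial product, and the partials are summed
// into C. Scalar reference of that decomposition (k_batch is a runtime argument not
// visible in these hunks; this helper is hypothetical):
float splitk_dot(const float* a_row, const float* b_col, int K, int k_batch)
{
    const int k_per_batch = (K + k_batch - 1) / k_batch;
    float c = 0.f;
    for(int kb = 0; kb < k_batch; ++kb) // one slice per work partition
    {
        float partial = 0.f;
        for(int k = kb * k_per_batch; k < K && k < (kb + 1) * k_per_batch; ++k)
            partial += a_row[k] * b_col[k];
        c += partial; // partial results accumulated into the final C value
    }
    return c;
}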
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +32,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, @@ -45,50 +51,16 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format on >; -// using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< -// // clang-format off -// //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| -// B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| -// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| -// ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| -// BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| -// CBlockTransferClusterLengths| CBlockTransfer| -// //#########################| Type| Type| Type| Type| | | | -// Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | -// XDL| XDL| Per| Per| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| -// SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| -// SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| -// _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -// //#########################| | | | | | | | -// Operation| Operation| Operation| | | | | | | | -// | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| -// PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | -// PerVector| PerVector_K1| | PerShuffle| PerShuffle| -// _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -// //#########################| | | | | | | | | | -// | | | | | | | | | | | | -// | | | | | | | | | | | | -// | | | | | -// DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, -// PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, -// 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, -// true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, -// true, 1, 9, S<1, 2, 1, 72>, 2> -// // clang-format on -// >; - void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); - - // FIXME - IsSupportedArgument() is false, need to check validity - // add_device_operation_instances( - // instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp similarity index 89% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp index 971bdcad5832d0f52920d047db272e8bddd6ad8b..051ff652b94e44f5cb9e2a14b7b4680c5d3fe2f3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_splitk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
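[Editorial aside, not part of the patch] The FIXME removed just above refers to IsSupportedArgument() returning false for the irregular-tile split-K instance. For readers unfamiliar with the instance-registration pattern, the following is a minimal, self-contained sketch of why that check matters; none of the names below (Problem, BaseOp, TiledOp) are the library's real classes, and the tile sizes are only illustrative.

// Sketch: every registered instance is asked whether it supports the problem,
// and unsupported ones are skipped. Illustrative code only.
#include <cstdio>
#include <memory>
#include <vector>

struct Problem { int M, N, K; };

struct BaseOp {
    virtual ~BaseOp() = default;
    virtual bool IsSupportedArgument(const Problem& p) const = 0;
    virtual void Run(const Problem& p) const = 0;
};

// An "instance" whose tile sizes must divide the problem sizes, mimicking the
// divisibility checks a real tiled GEMM instance performs.
template <int MPerBlock, int NPerBlock>
struct TiledOp final : BaseOp {
    bool IsSupportedArgument(const Problem& p) const override
    {
        return p.M % MPerBlock == 0 && p.N % NPerBlock == 0;
    }
    void Run(const Problem& p) const override
    {
        std::printf("running %dx%d tile on %dx%dx%d\n", MPerBlock, NPerBlock, p.M, p.N, p.K);
    }
};

int main()
{
    std::vector<std::unique_ptr<BaseOp>> instances;
    instances.push_back(std::make_unique<TiledOp<256, 128>>());
    instances.push_back(std::make_unique<TiledOp<128, 144>>()); // irregular tile, like the commented-out one above
    instances.push_back(std::make_unique<TiledOp<64, 64>>());

    const Problem p{512, 256, 1024};
    for(const auto& op : instances)
    {
        if(op->IsSupportedArgument(p)) // an unsupported instance is simply skipped
        {
            op->Run(p);
            break;
        }
    }
    return 0;
}

In practice a profiler would time every supported instance and keep the fastest one; this sketch just takes the first supported match.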
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +32,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, @@ -41,13 +47,15 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp similarity index 89% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp index 3b7bdb87be0c65f17e65a8c174e408f049ad6df1..5d3cbf896b89dc337ee361e2d7f89b9153a3e6a1 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_splitk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +32,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, @@ -41,13 +47,15 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp similarity index 92% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp index 8366616246e566bf7a63a8944d6b24021ab7502e..9a9b05a32634ab453e7c1a1882a86837ca0b80f8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_splitk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +32,7 @@ static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpeciali using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off //###################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM|Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //###################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //###################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, @@ -46,13 +52,15 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( 
- std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp similarity index 92% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp index 396de62cfb2696e8b800f688d0d316827e72bfc9..50dc93051d1e39cb0661aca5ec16d9b303540be3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,13 +1,19 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl_splitk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +32,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdlSplitK< F32, F32, 
F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, @@ -46,13 +52,15 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( - std::vector>& instances) + std::vector>>& + instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{}); } -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index 6c5e31fddd37fd4051299c243e8c9a553b67593c..4d1115ceb6457bafec42d4b5c028bc1f71010225 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -6,10 +6,10 @@ set(DEVICE_GROUPED_GEMM_INSTANCE_SOURCE device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; ) -add_library(device_grouped_gemm_instance OBJECT ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE}) +add_library(device_grouped_gemm_instance OBJECT ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE}) target_compile_features(device_grouped_gemm_instance PUBLIC) set_target_properties(device_grouped_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_grouped_gemm_instance LIBRARY DESTINATION lib) +rocm_install(TARGETS device_grouped_gemm_instance) clang_tidy_check(device_grouped_gemm_instance) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp index 19f1011c3f1a71a83e24ef1335e3d5117235d898..ebc4cc952b3c8cda6932a713ffe81158ef02bcf5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_grouped_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_grouped_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, @@ -47,7 +52,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); } -} // namespace device_grouped_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp index 59e0d240555187f8c18490925dfa2a8ebf94bdb7..e604f15e2364e9b8cfc9553a42e64d5ffc9da527 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_grouped_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_grouped_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, @@ -47,7 +52,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); } -} // namespace device_grouped_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index 35052ae8a93d29e2b41a9094655252df0f969f28..1b7ecb588486fe902e9b87da2d9ffb7bf7de6335 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_grouped_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_grouped_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -26,7 +31,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, @@ -56,7 +61,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); } -} // namespace device_grouped_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index cb41d2724c4d02331ab5dcd3f37d050829ceae03..65c88817f4abca6f148968659f9b513048c56f35 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,13 +1,18 @@ -#include -#include "config.hpp" -#include "device_grouped_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_grouped_gemm_instance { +namespace instance { using F16 = ck::half_t; using F32 = float; @@ -27,7 +32,7 @@ static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpeciali using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, @@ -50,7 +55,7 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< // clang-format off //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, @@ -67,7 +72,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( instances, device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); } -} // namespace device_grouped_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a6ae07bab9cdc2e05f6390d3cfdd2b96fa057918 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt @@ -0,0 +1,10 @@ +# device_normalization_instance +set(DEVICE_NORMALIZATION_INSTANCE_SOURCE + device_softmax_f32_f32_instance.cpp + device_softmax_f16_f16_instance.cpp +) + +add_library(device_normalization_instance OBJECT ${DEVICE_NORMALIZATION_INSTANCE_SOURCE}) +set_target_properties(device_normalization_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_normalization_instance) diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8465baa17cdaf4e13a2fc4bad28b7cd58b1a737b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using device_softmax_f16_f16_instances = std::tuple< + // clang-format off + // InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + DeviceSoftmax, // fallback kernel + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax + // clang-format on + >; + +void add_device_softmax_f16_f16_rank3_instances(std::vector& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 1>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 2>{}); +} + +void add_device_softmax_f16_f16_rank4_instances(std::vector& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 1>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 2>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73ecf747b2728873dd461bf704f784b65a3f3df0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using device_softmax_f32_f32_instances = std::tuple< + // clang-format off + // InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + DeviceSoftmax, // fallback kernel + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax + // clang-format on + >; + +void add_device_softmax_f32_f32_rank3_instances(std::vector& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 1>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 2>{}); +} + +void add_device_softmax_f32_f32_rank4_instances(std::vector& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 1>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 2>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp index 0274d89fc9eb6dfa522b846b51443b4647572cb9..c97efbc901a256280c29a9923d9a23a2a00efcdc 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_blockwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -46,7 +49,7 @@ ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp index 8a43d860ea733213551eb3dc54d077b182ad333b..5e73b3d8b94f5842ffb3c60dba66b4288e9c3c77 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_blockwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
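[Editorial aside, not part of the patch] The softmax instance files above register their configurations by passing a std::tuple of concrete operation types to add_device_operation_instances. The sketch below shows that pattern in isolation: one heap-allocated instance per tuple element type, appended to a vector of base-class pointers. It is a simplified illustration, not the library's actual helper; SoftmaxConfigA/B stand in for the DeviceSoftmax<...> configurations whose template arguments were elided above.

// Sketch of the tuple-based registration pattern (C++17 fold expression).
#include <memory>
#include <tuple>
#include <vector>

struct BaseOperator {
    virtual ~BaseOperator() = default;
};

// Stand-ins for concrete DeviceSoftmax configurations.
struct SoftmaxConfigA : BaseOperator {};
struct SoftmaxConfigB : BaseOperator {};

template <typename... Ops>
void add_device_operation_instances(std::vector<std::unique_ptr<BaseOperator>>& instances,
                                    const std::tuple<Ops...>&)
{
    // One instance per tuple element type; the tuple value itself only carries the types.
    (instances.push_back(std::make_unique<Ops>()), ...);
}

int main()
{
    std::vector<std::unique_ptr<BaseOperator>> instances;
    add_device_operation_instances(instances, std::tuple<SoftmaxConfigA, SoftmaxConfigB>{});
    return instances.size() == 2 ? 0 : 1;
}

This is why the add_device_softmax_*_rank3/rank4 functions above simply call add_device_operation_instances once per (Rank, NumReduceDim) alias: each call appends every tuning configuration in that tuple to the caller's vector.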
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -33,7 +36,7 @@ ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp index 3e0b8ba59c7eb6b70c968c5134ebf7308c21d618..93d3e27016a6a8c914ca97f52c83d0c35369b43a 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_blockwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -21,7 +24,7 @@ ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp index ee96311f8ce1cf1f8cd0bca74f2e8b533703cd48..38800ddde5a55053d57fddcd648d5bc842aafecc 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_blockwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -45,7 +48,7 @@ ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp index b0ae95e82d9af3d22bf9325cbf3bacf603d678cf..b821aeee0adf4a4dd63f0754c9efab5bca632bb2 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_blockwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -21,8 +24,7 @@ ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1); ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp index 9cca2dbbeb9ae51eb2861de2356bd98457ea9e93..074d0cfdf7bd8203ae6f005ea4ac6d22f14673a7 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_blockwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -45,7 +48,7 @@ ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp index 05cd1921ee70536f0d80e46e23e1d4a950c2b534..e803fb842d2ffb0430f2383c075fb97e3e81f1ed 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_blockwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -17,7 +20,7 @@ ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp index 66ef01786438507728e95a61a3544ec89ca18ed5..4bf4139d28dca0e0bb6146bf792ec72797390f2c 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_blockwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -33,7 +36,7 @@ ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp index 9b2b7f5d8c1f8aa16df6a887235686211ae2467a..a571655cdcf98cb77d10b038bb779f23d5995537 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_multiblock_atomic_add.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -17,8 +20,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp index fc956aa04b645d4561cd6cacb360f8018ab72fb6..9ad9a630bd8b32f00238e933e505755f66b3a499 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_multiblock_atomic_add.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -17,7 +20,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp index e5ffd9f976d577ca05c4f3dd17ad6132726e4f3f..4ee70702c06e6e76f5671641f93e12775dc4c5f6 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_multiblock_atomic_add.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -17,8 +20,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp index 229829b88971134bce127569672979df3fb0b35f..8c5fa80e8140dafc3fbfef20641e4173738aacb2 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_multiblock_atomic_add.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -17,8 +20,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp index 497f2695be00b89456a225aaa4c01fbfa9f480c1..d2b81c486d968b9a542ad523a2961152b07f3060 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_multiblock_atomic_add.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -17,7 +20,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp index 02fc4b4c01a8ef9b0dccd46ba42db8c655652280..8d678e784ae2fd9a15a303841441d9ecbf612bad 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_threadwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -46,7 +49,7 @@ ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp index 0984cdc46b9a5fdf15ea36344b852a9eb7feabba..010560586a659876177d1fa81e6ebb06492073c0 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_threadwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -33,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp index 64f14bd4e72c8aa321e61d03ab4e8976b3abb87f..55c53dfd586b8387d13c7950a3c1882bef1266cc 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_threadwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -21,8 +24,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp index 69ed303b177d47b03cb3afa8b5d62dc23599104e..367cf9a65d48c22fe337bd5e5ce123a76c3abe52 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_threadwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -45,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp index 5d791cec41034fa0eba0427c9c3cc856b08dbf5c..18fd08448ccfaba4501cec33f392e7319a26969e 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_threadwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -21,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1); ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp index 16c0409134ac0b07d4875acc30c6ffc98de2d439..3d02f3cbe3052f2a64da5505cf83279e0bf77617 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_threadwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -45,8 +48,7 @@ ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation - } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp index 7af7bc03f28587fa77d71ec5026f9341bbe6f46c..fcf072a0864f2d25668a7d3141991ad116d42e46 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_threadwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -18,7 +21,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); // clang-format on // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp index 9580aae057d410e493fa48341d3b1b8f539d76a8..85d7ce8b4c9e1628d76e05f43827d979d98ad3e6 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp @@ -1,9 +1,12 @@ -#include "device_reduce_instance_threadwise.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { // clang-format off // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim @@ -33,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); // clang-format on -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/utility/CMakeLists.txt b/library/src/utility/CMakeLists.txt index 0914855d59fecfc29b21c8c6bc07976573bb7b99..afa6de5119634753911776529f63b06ce28bba8e 100644 --- a/library/src/utility/CMakeLists.txt +++ b/library/src/utility/CMakeLists.txt @@ -1,13 +1,3 @@ -include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/include/ck - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element - ${PROJECT_SOURCE_DIR}/include/ck/utility - ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor - ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu - ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility -) - set(CONV_UTIL_SOURCE conv_util.cpp ) diff --git a/library/src/utility/conv_util.cpp b/library/src/utility/conv_util.cpp index a60d1a34952f29d9647558d680ffe72bd95fa5e8..3a223770cddfd6472f135bdeb8f789769b5e6dfd 100644 --- a/library/src/utility/conv_util.cpp +++ b/library/src/utility/conv_util.cpp @@ -1,5 +1,7 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "conv_util.hpp" +#include "ck/library/utility/conv_util.hpp" namespace ck { namespace utils { diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index c763127a6839d961ce921662272097e11473ef17..aa0cde57950d03196a05e062f21e133b321a5620 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -1,66 +1,50 @@ include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/include/ck - ${PROJECT_SOURCE_DIR}/include/ck/utility - ${PROJECT_SOURCE_DIR}/include/ck/host_utility - ${PROJECT_SOURCE_DIR}/include/ck/tensor_description - ${PROJECT_SOURCE_DIR}/include/ck/tensor - ${PROJECT_SOURCE_DIR}/include/ck/problem_transform - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element - ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor - ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance - ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce - ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu - ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu - ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility - ${PROJECT_SOURCE_DIR}/profiler/include - ${PROJECT_SOURCE_DIR}/external/include/half + ${PROJECT_SOURCE_DIR}/ ) # ck_profiler set(PROFILER_SOURCE src/profiler.cpp src/profile_gemm.cpp - src/profile_gemm_bias_2d.cpp - src/profile_gemm_bias_relu.cpp - src/profile_gemm_bias_relu_add.cpp + src/profile_gemm_splitk.cpp + src/profile_gemm_bilinear.cpp + src/profile_gemm_bias_add_reduce.cpp + src/profile_gemm_add_add_fastgelu.cpp src/profile_gemm_reduce.cpp src/profile_batched_gemm.cpp + src/profile_batched_gemm_reduce.cpp + src/profile_grouped_gemm.cpp src/profile_conv_fwd_bias_relu.cpp src/profile_conv_fwd_bias_relu_add.cpp - src/profile_conv_fwd_bias_relu_atomic_add.cpp src/profile_convnd_fwd.cpp src/profile_conv_fwd_cpu.cpp src/profile_convnd_bwd_data.cpp - src/profile_reduce.cpp - src/profile_grouped_gemm.cpp src/profile_conv_bwd_weight.cpp - src/profile_batched_gemm_reduce.cpp + src/profile_convnd_bwd_weight.cpp + src/profile_reduce.cpp + src/profile_normalization.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE conv_util) -target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_splitk_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bilinear_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) +target_link_libraries(ckProfiler PRIVATE 
device_grouped_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_cpu_instance) target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) -target_link_libraries(ckProfiler PRIVATE device_reduce_instance) -target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) -target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_weight_instance) +target_link_libraries(ckProfiler PRIVATE device_normalization_instance) +target_link_libraries(ckProfiler PRIVATE device_reduce_instance) diff --git a/include/ck/utility/data_type_enum.hpp b/profiler/include/data_type_enum.hpp similarity index 63% rename from include/ck/utility/data_type_enum.hpp rename to profiler/include/data_type_enum.hpp index fda6a2b05cf0d0c514f630d0e13bd501784b3654..afcd6fea224f3e06f29309349bcc6ad55caa7d4f 100644 --- a/include/ck/utility/data_type_enum.hpp +++ b/profiler/include/data_type_enum.hpp @@ -1,5 +1,7 @@ -#ifndef CK_DATA_TYPE_ENUM_HPP -#define CK_DATA_TYPE_ENUM_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once namespace ck { @@ -16,4 +18,3 @@ enum struct DataTypeEnum }; } // namespace ck -#endif diff --git a/include/ck/utility/data_type_enum_helper.hpp b/profiler/include/data_type_enum_helper.hpp similarity index 86% rename from include/ck/utility/data_type_enum_helper.hpp rename to profiler/include/data_type_enum_helper.hpp index 9c8e01a7e38b8c1f44750969d807b7ba517c5693..6f8ef2b9f75f39ce26bc0758c62492bf7fdcef2f 100644 --- a/include/ck/utility/data_type_enum_helper.hpp +++ b/profiler/include/data_type_enum_helper.hpp @@ -1,8 +1,10 @@ -#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP -#define CK_DATA_TYPE_ENUM_HELPER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "data_type.hpp" -#include "data_type_enum.hpp" +#pragma once + +#include "ck/utility/data_type.hpp" +#include "profiler/include/data_type_enum.hpp" namespace ck { @@ -73,4 +75,3 @@ struct get_datatype_enum_from_type }; } // namespace ck -#endif diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp index 3393110c33ea1d9c648bf9d860f5b68bb2719161..0da9a26cf55da62b841e6496be5fcdd74fd70105 100644 --- a/profiler/include/profile_batched_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -1,55 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+ #pragma once #include -#include "check_err.hpp" -#include "config.hpp" -#include "element_wise_operation.hpp" -#include "tensor_layout.hpp" -#include "device.hpp" -#include "host_tensor_generator.hpp" -#include "device_gemm.hpp" -#include "reference_batched_gemm.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_batched_gemm_instance { - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( - std::vector&); - -} // namespace device_batched_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" namespace ck { namespace profiler { @@ -67,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification, int M, int N, int K, + int BatchStrideA, + int BatchStrideB, + int BatchStrideC, int StrideA, int StrideB, int StrideC, @@ -78,46 +48,44 @@ bool profile_batched_gemm_impl(int do_verification, std::size_t row, std::size_t col, std::size_t stride, + std::size_t batch_stride, auto layout) { if(is_same::value) { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({row * stride, stride, 1})); + std::vector({batch_stride, stride, 1})); } else { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({col * stride, 1, stride})); + std::vector({batch_stride, 1, stride})); } }; - Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); - Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b_g_k_n( + 
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{})); Tensor c_g_m_n_host_result( - f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{})); Tensor c_g_m_n_device_result( - f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - std::unique_ptr> c_f32_g_m_n_host_result = nullptr; - std::unique_ptr> c_f32_g_m_n_device_result = nullptr; + f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{})); std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; - std::size_t num_thread = 1; switch(init_method) { case 0: break; case 1: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - // set zero to c_device_buf - c_g_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -129,56 +97,21 @@ bool profile_batched_gemm_impl(int do_verification, if(do_verification) { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - Tensor a_f32_g_m_k( - f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); - Tensor b_f32_g_k_n( - f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); - c_f32_g_m_n_host_result = std::make_unique>( - f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - c_f32_g_m_n_device_result = std::make_unique>( - f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - - bf16_to_f32_(a_g_m_k, a_f32_g_m_k); - bf16_to_f32_(b_g_k_n, b_f32_g_k_n); - - using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: - ReferenceBatchedGemm; - - auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; - auto ref_invoker = ref_batched_gemm.MakeInvoker(); - - auto ref_argument = ref_batched_gemm.MakeArgument(a_f32_g_m_k, - b_f32_g_k_n, - *c_f32_g_m_n_host_result, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - } - else - { + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; - using ReferenceBatchedGemmInstance = - ck::tensor_operation::host::ReferenceBatchedGemm; + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); - auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; - auto ref_invoker = ref_batched_gemm.MakeInvoker(); + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); - auto ref_argument = ref_batched_gemm.MakeArgument( - a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - } + ref_invoker.Run(ref_argument); } DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); @@ -189,172 +122,56 @@ bool profile_batched_gemm_impl(int do_verification, 
b_device_buf.ToDevice(b_g_k_n.mData.data()); c_device_buf.ToDevice(c_g_m_n_device_result.mData.data()); - // add device GEMM instances - std::vector - gemm_ptrs; + using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm; - if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - 
ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(gemm_ptrs); - } - } + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); - if(gemm_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device GEMM instance found"); - } + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - std::string best_gemm_name; + std::string best_op_name; float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; - // profile device GEMM instances - for(auto& gemm_ptr : gemm_ptrs) + // profile device op instances + for(auto& op_ptr : op_ptrs) { auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - BatchCount); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - std::string gemm_name = gemm_ptr->GetTypeString(); + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchCount, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); @@ -370,11 +187,11 @@ bool profile_batched_gemm_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm_name << std::endl; + << " GB/s, " << op_name << std::endl; if(tflops > best_tflops) { - best_gemm_name = gemm_name; + best_op_name = op_name; best_tflops = tflops; best_ave_time = ave_time; best_gb_per_sec = gb_per_sec; @@ -384,20 +201,8 @@ bool profile_batched_gemm_impl(int do_verification, { c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - - bf16_to_f32_(c_g_m_n_device_result, *c_f32_g_m_n_device_result); - float err = check_error(*c_f32_g_m_n_host_result, *c_f32_g_m_n_device_result); - pass = pass && (err < 1E-6); - } - else - { - float err = check_error(c_g_m_n_host_result, c_g_m_n_device_result); - pass = pass && (err < 1E-6); - } + pass = pass & + 
ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData); if(do_log) { @@ -413,13 +218,12 @@ bool profile_batched_gemm_impl(int do_verification, } else { - std::cout << "this device GEMM instance does not support this GEMM problem" - << std::endl; + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; } } std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; return pass; } diff --git a/profiler/include/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profile_batched_gemm_reduce_impl.hpp index 7ba04726864e0b266bf7efb0e134f2b1104b5533..b7dc979577c088039fe7024616d4e568a755823d 100644 --- a/profiler/include/profile_batched_gemm_reduce_impl.hpp +++ b/profiler/include/profile_batched_gemm_reduce_impl.hpp @@ -1,37 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_gemm_reduce.hpp" -#include "reference_batched_gemm.hpp" +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { - -using F32 = float; -using F16 = ck::half_t; -using DPtrsGlobal = ck::Tuple; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; - -using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< - DPtrsGlobal, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - DInElementOps, - DOutElementOps>; +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( std::vector&); @@ -45,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( std::vector&); -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace 
tensor_operation } // namespace ck @@ -56,7 +55,7 @@ namespace profiler { template @@ -96,16 +95,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification, Tensor c_g_m_n_host_result( f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( + Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( {static_cast(BatchCount), static_cast(M)}))); - Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( + Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( {static_cast(BatchCount), static_cast(M)}))); Tensor c_g_m_n_device_result( f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( + Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( {static_cast(BatchCount), static_cast(M)}))); - Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( + Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( {static_cast(BatchCount), static_cast(M)}))); std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; @@ -129,25 +128,26 @@ bool profile_batched_gemm_reduce_impl(int do_verification, b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - using D0ReduceOp = ck::reduce::Add; - using D1ReduceOp = ck::reduce::Add; - using UnaryIdenticElementOp = - ck::tensor_operation::element_wise::UnaryIdentic; - using UnarySquareElementOp = - ck::tensor_operation::element_wise::UnarySquare; - using DxsInElementOps = ck::Tuple; - using DxsOutElementOps = ck::Tuple; - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - const auto dxs_in_element_op = DxsInElementOps{}; - const auto dxs_out_element_op = DxsOutElementOps{}; - const auto d0_reduce_op = D0ReduceOp{}; - const auto d1_reduce_op = D1ReduceOp{}; + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&passthrough, &passthrough}; if(do_verification) { @@ -159,6 +159,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification, BElementOp, CElementOp>; + using ReduceAccDataType = ReduceDataType; + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; auto ref_invoker = ref_batched_gemm.MakeInvoker(); @@ -171,21 +173,22 @@ bool profile_batched_gemm_reduce_impl(int do_verification, { for(int m = 0; m < M; ++m) { - float d0_acc = d0_reduce_op.GetIdentityValue(); - float d1_acc = d1_reduce_op.GetIdentityValue(); + auto reduce0_acc = 
reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); for(int n = 0; n < N; ++n) { - float d0_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); - float d1_val; + ReduceAccDataType d0_val = + ck::type_convert(c_g_m_n_host_result(batch, m, n)); + ReduceAccDataType d1_val; - UnarySquareElementOp{}(d1_val, d0_val); - d0_reduce_op(d0_acc, d0_val); - d1_reduce_op(d1_acc, d1_val); + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); } - d0_g_m_host_result(batch, m) = ck::type_convert(d0_acc); - d1_g_m_host_result(batch, m) = ck::type_convert(d1_acc); + d0_g_m_host_result(batch, m) = ck::type_convert(reduce0_acc); + d1_g_m_host_result(batch, m) = ck::type_convert(reduce1_acc); } } } @@ -193,18 +196,19 @@ bool profile_batched_gemm_reduce_impl(int do_verification, DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace()); - DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + d0_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + d1_g_m_device_result.mDesc.GetElementSpace()); - auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), - static_cast(d1_device_buf.GetDeviceBuffer())); + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; a_device_buf.ToDevice(a_g_m_k.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data()); // add device GEMM instances - std::vector - gemm_ptrs; + std::vector gemm_ptrs; if constexpr(is_same::value && is_same::value && is_same::value) @@ -213,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( gemm_ptrs); } @@ -221,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( gemm_ptrs); } @@ -229,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( gemm_ptrs); } @@ -237,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( gemm_ptrs); } @@ -256,31 +260,32 @@ bool profile_batched_gemm_reduce_impl(int do_verification, // profile device GEMM instances for(auto& gemm_ptr : gemm_ptrs) { - auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), 
- dxs_global, - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - dxs_in_element_op, - dxs_out_element_op, - BatchCount); + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops, + BatchCount); auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { // init DO, D1 to 0 - d0_device_buf.SetZero(); - d1_device_buf.SetZero(); + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); @@ -310,8 +315,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification, if(do_verification) { c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); - d0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); - d1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); float c_error = check_error(c_g_m_n_host_result, c_g_m_n_device_result); float d0_error = check_error(d0_g_m_host_result, d0_g_m_device_result); diff --git a/profiler/include/profile_conv_bwd_weight_impl.hpp b/profiler/include/profile_conv_bwd_weight_impl.hpp index 8e3a4074b081a1d4318f28665897a8be38b1d9b9..9820d978fd0afd53e8ba0ff08762f4052127c9da 100644 --- a/profiler/include/profile_conv_bwd_weight_impl.hpp +++ b/profiler/include/profile_conv_bwd_weight_impl.hpp @@ -1,20 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "stream_config.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "device_conv_backward_weight.hpp" -#include "element_wise_operation.hpp" -#include "reference_conv_backward_weight.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_weight_instance { +namespace instance { using DeviceConvBwdWeightNoOpPtr = DeviceConvBwdWeightPtr&); -} // namespace device_conv2d_bwd_weight_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -161,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification, ck::is_same_v, float> && ck::is_same_v, float>) { - ck::tensor_operation::device::device_conv2d_bwd_weight_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); } else if constexpr(ck::is_same_v, ck::half_t> && ck::is_same_v, ck::half_t> && ck::is_same_v, ck::half_t>) { - ck::tensor_operation::device::device_conv2d_bwd_weight_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); } diff --git a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp index 5ea35cd72f1f10ef7644b324e1ba6d3d6a893c27..69bfe50a70d5923695a22dfdf0f2833407e70884 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp @@ -1,20 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_conv_fwd_bias_activation_add.hpp" -#include "reference_conv_fwd_bias_activation_add.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_fwd_bias_activation_add_instance { +namespace instance { using DeviceConvFwdBiasReluAddPtr = DeviceConvFwdBiasActivationAddPtr&); -} // namespace device_conv2d_fwd_bias_activation_add_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -176,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, ck::is_same_v, ck::half_t> && ck::is_same_v, ck::half_t>) { - ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs); } diff --git a/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp deleted file mode 100644 index f1c2fd300accd5f651fa56f1898a27b44d6bf8e0..0000000000000000000000000000000000000000 --- a/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp +++ /dev/null @@ -1,331 +0,0 @@ -#pragma once -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "device_conv_fwd_bias_activation.hpp" -#include "element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv2d_fwd_bias_activation_atomic_add_instance { - -using DeviceConvFwdBiasReluPtr = - DeviceConvFwdBiasActivationPtr; - -void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( - std::vector&); - -} // namespace device_conv2d_fwd_bias_activation_atomic_add_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -namespace ck { -namespace profiler { - -void cpu_conv_bias_relu_atomic_add(ck::half_t* in_ptr, - ck::half_t* weight_ptr, - ck::half_t* output_ptr, - ck::half_t* bias_ptr, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const ck::index_t Y, - const ck::index_t X, - const ck::index_t Hi, - const ck::index_t Wi, - const ck::index_t Ho, - const ck::index_t Wo, - const ck::index_t Stride, - const ck::index_t Dilation, - const ck::index_t Pad) -{ - - const auto in_desc = - HostTensorDescriptor(std::vector{static_cast(N), - static_cast(Hi), - static_cast(Wi), - static_cast(C)}); - const auto wei_desc = - HostTensorDescriptor(std::vector{static_cast(K), - static_cast(Y), - static_cast(X), - static_cast(C)}); - const auto out_desc = - 
HostTensorDescriptor(std::vector{static_cast(N), - static_cast(Ho), - static_cast(Wo), - static_cast(K)}); - const auto bias_desc = - HostTensorDescriptor(std::vector{static_cast(K)}); - - auto f_k = [&](auto k) { - for(int n = 0; n < N; ++n) - { - for(int ho = 0; ho < Ho; ++ho) - { - for(int wo = 0; wo < Wo; ++wo) - { - double v = 0; - for(int c = 0; c < C; ++c) - { - for(int y = 0; y < Y; ++y) - { - int hi = ho * Stride + y * Dilation - Pad; - for(int x = 0; x < X; ++x) - { - int wi = wo * Stride + x * Dilation - Pad; - if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi) - { - double in = - in_ptr[in_desc.GetOffsetFromMultiIndex(n, hi, wi, c)]; - double wei = - weight_ptr[wei_desc.GetOffsetFromMultiIndex(k, y, x, c)]; - - v += in * wei; - } - } - } - } - - v += bias_ptr[bias_desc.GetOffsetFromMultiIndex(k)]; - - v = v > 0 ? v : 0; - - output_ptr[out_desc.GetOffsetFromMultiIndex(n, ho, wo, k)] = v; - } - } - } - }; - - make_ParallelTensorFunctor(f_k, K)(std::thread::hardware_concurrency()); -} - -template -void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) -{ - const ck::index_t Y = filter_spatial_lengths[0]; - const ck::index_t X = filter_spatial_lengths[1]; - - const ck::index_t Hi = input_spatial_lengths[0]; - const ck::index_t Wi = input_spatial_lengths[1]; - - const ck::index_t Ho = output_spatial_lengths[0]; - const ck::index_t Wo = output_spatial_lengths[1]; - - auto f_host_tensor_descriptor = - [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { - if constexpr(is_same::value || - is_same::value || - is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, H * W, W, 1})); - } - else if constexpr(is_same::value || - is_same::value || - is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - } - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); - Tensor out_n_k_ho_wo_host_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - Tensor out_n_k_ho_wo_device_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - - // bias: assume contiguous 1d vector - Tensor bias_k( - HostTensorDescriptor(std::vector({static_cast(K)}))); - - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; - std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; - std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; - std::cout << "bias_k: " << bias_k.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - - using InElementOp = ck::tensor_operation::element_wise::PassThrough; - using WeiElementOp = 
ck::tensor_operation::element_wise::PassThrough; - using OutElementOp = ck::tensor_operation::element_wise::AddRelu; - - if(do_verification) - { - cpu_conv_bias_relu_atomic_add(in_n_c_hi_wi.mData.data(), - wei_k_c_y_x.mData.data(), - out_n_k_ho_wo_host_result.mData.data(), - bias_k.mData.data(), - N, - K, - C, - Y, - X, - Hi, - Wi, - Ho, - Wo, - conv_filter_strides[0], - conv_filter_dilations[0], - input_left_pads[0]); - } - - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * - out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); - DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - bias_device_buf.ToDevice(bias_k.mData.data()); - - using DeviceConvFwdBiasReluPtr = ck::tensor_operation::device:: - DeviceConvFwdBiasActivationPtr; - - // add device operator instances - std::vector op_ptrs; - - if constexpr(ck::is_same_v, ck::half_t> && - ck::is_same_v, ck::half_t> && - ck::is_same_v, ck::half_t>) - { - ck::tensor_operation::device::device_conv2d_fwd_bias_activation_atomic_add_instance:: - add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( - op_ptrs); - } - - if(op_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device Conv instance found"); - } - - std::string best_conv_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - - // profile device Conv instances - for(auto& op_ptr : op_ptrs) - { - auto argument_ptr = op_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(bias_device_buf.GetDeviceBuffer()), - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - std::string conv_name = op_ptr->GetTypeString(); - - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; - - std::size_t num_btype = - sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << conv_name << std::endl; - - if(tflops > best_tflops) - { - best_conv_name = conv_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - - ck::utils::check_err(out_n_k_ho_wo_device_result.mData, - out_n_k_ho_wo_host_result.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") - << 
std::endl; - LogRangeAsType( - std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") - << std::endl; - } - } - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp index eeb2b93e4eea6a14f754a33854dc659f53125baf..166173ca896adc6b08b3fdcfa04ff3f1bc55101e 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp @@ -1,19 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_conv_fwd_bias_activation.hpp" -#include "reference_conv_fwd_bias_activation.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_fwd_bias_activation_instance { +namespace instance { using DeviceConvFwdBiasReluPtr = DeviceConvFwdBiasActivationPtr&); -} // namespace device_conv2d_fwd_bias_activation_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -165,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ck::is_same_v, ck::half_t> && ck::is_same_v, ck::half_t>) { - ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs); } diff --git a/profiler/include/profile_convnd_bwd_data_impl.hpp b/profiler/include/profile_convnd_bwd_data_impl.hpp index 291bf2abc086d3857d6dacea203e8c8e6525f43a..676e619b49dc262664de5b70d9f990a7cfb6cacc 100644 --- a/profiler/include/profile_convnd_bwd_data_impl.hpp +++ b/profiler/include/profile_convnd_bwd_data_impl.hpp @@ -1,23 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once -#include "config.hpp" -#include "device.hpp" -#include "conv_util.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "device_conv_bwd_data.hpp" -#include "element_wise_operation.hpp" -#include "reference_conv_bwd_data.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" using F16 = ck::half_t; using F32 = float; using BF16 = ck::bhalf_t; using INT8 = int8_t; + namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using DeviceConvBwdDataNoOpPtr = DeviceConvBwdDataPtr&); void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( std::vector&); -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck namespace ck { namespace profiler { -using DeviceConvBwdDataNoOpPtr = - ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr; +using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr; template HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, @@ -139,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr( switch(num_dim_spatial) { case 1: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); break; case 2: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); break; case 3: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); break; default: break; @@ -160,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr( switch(num_dim_spatial) { case 1: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); break; case 2: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); break; case 3: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); break; default: break; @@ -181,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr( switch(num_dim_spatial) { case 1: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); break; case 2: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); 
break; case 3: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); break; default: break; @@ -202,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr( switch(num_dim_spatial) { case 1: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); break; case 2: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); break; case 3: - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); break; default: break; diff --git a/profiler/include/profile_convnd_bwd_weight_impl.hpp b/profiler/include/profile_convnd_bwd_weight_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c32abd96b36735343b2ac8099f7be5c6bfa21c25 --- /dev/null +++ b/profiler/include/profile_convnd_bwd_weight_impl.hpp @@ -0,0 +1,478 @@ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvndBwdWeightNoOpPtr = + DeviceConvBwdWeightPtr; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances( + std::vector&); +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector&); + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances( + std::vector&); +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector&); + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +using DeviceConvndBwdWeightNoOpPtr = + ck::tensor_operation::device::instance::DeviceConvndBwdWeightNoOpPtr; + +template +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, 
InLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +void get_device_conv_bwd_weight_op_ptr( + InDataType, WeiDataType, OutDataType, std::vector&, int) +{ + std::cout << "can not find device conv bwd weight" << std::endl; + exit(1); +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + F32, F32, F32, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); + break; + default: break; + } +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + F16, F16, F16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); + break; + default: break; + } +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + BF16, BF16, BF16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + break; + default: break; + } +} + +template +void show_data_nhwc_layout(Tensor& nhwc) +{ + std::cout << "["; + for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) + { + std::cout << "["; + for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) + { + std::cout << "["; + for(int wi = 0; wi < 
ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) + { + std::cout << "["; + for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) + { + std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; +} + +template +bool profile_convnd_bwd_weight_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t split_k) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + std::vector input_dims{static_cast(N), static_cast(C)}; + input_dims.insert( + std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths)); + + std::vector filter_dims{static_cast(K), static_cast(C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths)); + + std::vector output_dims{static_cast(N), static_cast(K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor weights_host_result( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor weights_device_result( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor output( + get_output_host_ensor_descriptor(output_dims, NDimSpatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights_host_result.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + output.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_1{1}); + output.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + out_device_buf.ToDevice(output.mData.data()); + + // reset input to zero + wei_device_buf.SetZero(); + + if(do_verification) + { + auto RunReference = [&](auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input, + weights_host_result, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + }; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight(); + RunReference(ref_conv); + } + + // add device Conv instances + std::vector conv_ptrs; + get_device_conv_bwd_weight_op_ptr( + InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial); + + if(conv_ptrs.size() 
<= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + // using atomic, so need to reset input, setzero is done in invoker + // if(split_k > 1) + //{ + // wei_device_buf.SetZero(); + //} + + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + + if(!conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + continue; + } + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + std::string conv_name = conv_ptr->GetTypeString(); + float ave_time = 0; + + if(std::is_same::value && split_k > 1) + { + // alloc work space + size_t bwd_weight_workspace_size = conv_ptr->GetWorkSpaceSize(argument_ptr.get()); + if(bwd_weight_workspace_size <= 0) + { + printf("wrong work space size\n"); + exit(1); + } + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + conv_ptr->SetWorkSpacePointer(argument_ptr.get(), + wei_work_space_device_buf.GetDeviceBuffer()); + ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + } + else + { + ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + } + + std::size_t flop = + ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + wei_device_buf.FromDevice(weights_device_result.mData.data()); + + float max_error = check_error(weights_host_result, weights_device_result); + + if(max_error > 8) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + + success = false; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + + check_error(weights_host_result, weights_device_result); + + if(do_log) + { + std::cout << "in : "; + show_data_nhwc_layout(output); + std::cout << std::endl; + + std::cout << "wei: "; + show_data_nhwc_layout(weights_host_result); + std::cout << std::endl; + + std::cout << "out : "; + show_data_nhwc_layout(input); + std::cout << std::endl; + + std::cout << "wei_device: "; + show_data_nhwc_layout(weights_device_result); + std::cout << std::endl; + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; + return success; +} + +} // namespace profiler +} // namespace ck diff 
--git a/profiler/include/profile_convnd_fwd.hpp b/profiler/include/profile_convnd_fwd.hpp deleted file mode 100644 index a3b55a79d1f6ffe0143f768ce4ae95a7201835cd..0000000000000000000000000000000000000000 --- a/profiler/include/profile_convnd_fwd.hpp +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -namespace ck { -namespace profiler { - -int profile_convnd_fwd(int argc, char* argv[]); - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp b/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..849b6f3ea287357efd038ac5ee4e235202a7eabd --- /dev/null +++ b/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template // assume Ds and E have same layout +bool profile_gemm_add_add_fastgelu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, DELayout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, DELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; 
+ + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = AddAddFastGelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + DELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace()); + DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + 
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && + ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bias_2d_impl.hpp b/profiler/include/profile_gemm_bias_2d_impl.hpp deleted file mode 100644 index 8565f9637c3b8bcc3014ccb47f03b00bf122c3f2..0000000000000000000000000000000000000000 --- a/profiler/include/profile_gemm_bias_2d_impl.hpp +++ /dev/null @@ -1,314 +0,0 @@ -#pragma once - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_gemm_bias.hpp" -#include "reference_gemm_bias_2d.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr< - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::AlphaBetaAdd>; - -void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances( - std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -namespace ck { -namespace profiler { - -template -void profile_gemm_bias_2d_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC, - float alpha, - float beta) -{ - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - - std::size_t 
num_thread = 1; - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - c0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - c0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - } - - // set zero to c_device_buf - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::AlphaBetaAdd; - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{alpha, beta}; - - if(do_verification) - { - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c0_m_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c0_device_buf(sizeof(C0DataType) * c0_m_n.mDesc.GetElementSpace()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - c0_device_buf.ToDevice(c0_m_n.mData.data()); - c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - - // add device GEMM instances - std::vector - gemm_ptrs; - - if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - 
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - } - } - - if(gemm_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device GEMM instance found"); - } - - std::string best_gemm_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - - // profile device GEMM instances - for(auto& gemm_ptr : gemm_ptrs) - { - auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c0_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - std::string gemm_name = gemm_ptr->GetTypeString(); - - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm_name << std::endl; - - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c0 : ", c0_m_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") - << std::endl; - } - } - } - else - { - std::cout << "does not support this GEMM problem" << std::endl; - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profile_gemm_bias_add_reduce_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..34317c59a7b383f2aa713d50214a2feeed012e2c --- /dev/null +++ b/profiler/include/profile_gemm_bias_add_reduce_impl.hpp @@ -0,0 +1,388 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmBiasAddReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_add_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int StrideD0) +{ + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor d0_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduce0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor reduce1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduce0_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor reduce1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "reduce0_m: " << 
reduce0_m_host_result.mDesc << std::endl; + std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bias_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + bias_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = PassThrough; + using D0ElementOp = PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto d0_element_op = D0ElementOp{}; + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + using ReduceAccDataType = ReduceDataType; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + for(int n = 0; n < N; ++n) + { + ReduceAccDataType acc = static_cast(c_m_n_host_result(m, n)) + + static_cast(bias_n(n)); + + ReduceAccDataType d0 = static_cast(d0_m_n(m, n)); + c_element_op(acc, acc); + d0_element_op(d0, d0); + acc += d0; + c_m_n_host_result(m, n) = static_cast(acc); + } + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType d0_val = + ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d1_val; + + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + div(reduce0_acc, reduce0_acc); + div(reduce1_acc, reduce1_acc); + reduce0_m_host_result(m) = ck::type_convert(reduce0_acc); + reduce1_m_host_result(m) = ck::type_convert(reduce1_acc); + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(D0DataType) * 
d0_m_n.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + reduce0_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + reduce1_m_device_result.mDesc.GetElementSpace()); + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer()}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {StrideD0}, + gemm_element_ops, + {&d0_element_op}, + reduce_in_element_ops, + reduce_out_element_ops); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N + + sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + 
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(reduce0_m_device_result.mData, reduce0_m_host_result.mData); + ck::utils::check_err(reduce1_m_device_result.mData, reduce1_m_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_host: ", reduce0_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", reduce0_m_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_host: ", reduce1_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", reduce1_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bias_relu_add_impl.hpp b/profiler/include/profile_gemm_bias_relu_add_impl.hpp deleted file mode 100644 index 6fec17c199306321a0da68638460f3bb5398759a..0000000000000000000000000000000000000000 --- a/profiler/include/profile_gemm_bias_relu_add_impl.hpp +++ /dev/null @@ -1,289 +0,0 @@ -#pragma once - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_gemm_bias_activation_add.hpp" -#include "reference_gemm_bias_activation_add.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr< - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::AddReluAdd>; - -void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( - std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -namespace ck { -namespace profiler { - -template -void profile_gemm_bias_relu_add_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC, - int StrideC1, - int KBatch = 1) -{ - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - 
std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - // c0_n[n] - Tensor c0_n(HostTensorDescriptor( - std::vector({static_cast(N)}), std::vector({1}))); - - // c1_m_n[m ,n] - Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "c0_n: " << c0_n.mDesc << std::endl; - std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - c1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - - // set zero to c_device_buf - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}); - - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::AddReluAdd; - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - - if(do_verification) - { - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmBiasActivationAdd; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k, - b_k_n, - c_m_n_host_result, - c0_n, - c1_m_n, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); - DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - c0_n_device_buf.ToDevice(c0_n.mData.data()); - c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); - - // add device GEMM instances - std::vector - gemm_ptrs; - - if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( - gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( - gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - 
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( - gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( - gemm_ptrs); - } - } - - if(gemm_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device GEMM instance found"); - } - - std::string best_gemm_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - - // profile device GEMM instances - for(auto& gemm_ptr : gemm_ptrs) - { - auto argument_ptr = gemm_ptr->MakeArgumentPointer( - static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - static_cast(c0_n_device_buf.GetDeviceBuffer()), - static_cast(c1_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - StrideC1, - a_element_op, - b_element_op, - c_element_op, - KBatch); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - std::string gemm_name = gemm_ptr->GetTypeString(); - - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + - sizeof(CDataType) * M * N + sizeof(CDataType) * N + - sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm_name << std::endl; - - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "a: ", a_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c0: ", c0_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c1: ", c1_m_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") - << std::endl; - } - } - } - else - { - std::cout << "does not support this GEMM problem" << std::endl; - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profile_gemm_bias_relu_impl.hpp b/profiler/include/profile_gemm_bias_relu_impl.hpp deleted file mode 100644 index 69010becc5bda4f076faff225d78a4129c074be5..0000000000000000000000000000000000000000 --- a/profiler/include/profile_gemm_bias_relu_impl.hpp +++ /dev/null @@ -1,267 +0,0 @@ -#pragma once - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_gemm_bias_activation.hpp" 
-#include "reference_gemm_bias_activation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr< - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::AddRelu>; - -void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances( - std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -namespace ck { -namespace profiler { - -template -void profile_gemm_bias_relu_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC, - int KBatch = 1) -{ - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - // c0_n[n] - Tensor c0_n(HostTensorDescriptor( - std::vector({static_cast(N)}), std::vector({1}))); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "c0_n: " << c0_n.mDesc << std::endl; - - std::size_t num_thread = 1; - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - - // set zero to c_device_buf - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::AddRelu; - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - - if(do_verification) - { - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmBiasActivation; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_device_buf(sizeof(BDataType) * 
b_k_n.mDesc.GetElementSpace()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - c0_n_device_buf.ToDevice(c0_n.mData.data()); - - // add device GEMM instances - std::vector - gemm_ptrs; - - if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - } - } - - if(gemm_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device GEMM instance found"); - } - - std::string best_gemm_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - - // profile device GEMM instances - for(auto& gemm_ptr : gemm_ptrs) - { - auto argument_ptr = gemm_ptr->MakeArgumentPointer( - static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - static_cast(c0_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - KBatch); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - std::string gemm_name = gemm_ptr->GetTypeString(); - - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + - sizeof(CDataType) * M * N + sizeof(CDataType) * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm_name << std::endl; - - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c0 : ", c0_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") - << std::endl; - } - } - } - else - { - std::cout << "does not support 
this GEMM problem" << std::endl; - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profile_gemm_bilinear_impl.hpp b/profiler/include/profile_gemm_bilinear_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f273ff4417cbe36f1cd19c22888712b2db2a3886 --- /dev/null +++ b/profiler/include/profile_gemm_bilinear_impl.hpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template // assume Ds and E have same layout +bool profile_gemm_bilinear_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD, + int StrideE, + float alpha, + float beta) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = Bilinear; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{alpha, beta}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + DELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + 
ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_m_n_device_buf.ToDevice(d_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && + ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index a3400f89b3cadf54dcf56d7dd054e7ad80c008e2..54b9e05c067e5070ee33100e31b869c6d5f54c7b 100644 --- 
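// Illustrative aside: the verification loop above applies cde_element_op to the plain GEMM
// result C and the extra tensor D. Assuming Bilinear{alpha, beta} computes
// e = alpha * c + beta * d (as its constructor arguments suggest), the per-element epilogue is:

float bilinear_epilogue(float c, float d, float alpha, float beta)
{
    return alpha * c + beta * d; // c = (A x B)(m, n), d = D(m, n), result = E(m, n)
}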
a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -1,119 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + #include #include #include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_gemm.hpp" -#include "reference_gemm.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - std::vector&); - -void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector&); -void 
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); - -void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { namespace profiler { template -void profile_gemm_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC, - int KBatch) +int profile_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) { + bool pass = true; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(is_same::value) @@ -130,32 +59,25 @@ void profile_gemm_impl(int do_verification, Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; - std::size_t num_thread = 1; switch(init_method) { - // case 0: break; - case 0: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; + case 0: break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - // set zero to c_device_buf - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -172,303 +94,72 @@ void 
profile_gemm_impl(int do_verification, b_device_buf.ToDevice(b_k_n.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - // add device GEMM instances - std::vector gemm_ptrs; + using DeviceOp = ck::tensor_operation::device::DeviceGemm; - if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - } - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) + // Run reference GEMM + if(do_verification) { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + using ReferenceGemmInstance 
= ck::tensor_operation::host::ReferenceGemm; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - } - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - 
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); - } - } + auto ref_argument = ref_op.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); - if(gemm_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device GEMM instance found"); + ref_invoker.Run(ref_argument); } - std::string best_gemm_name; + std::string best_op_name; float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; // profile device GEMM instances - for(auto& gemm_ptr : gemm_ptrs) + for(auto& op_ptr : op_ptrs) { auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - KBatch); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { // re-init C to zero before profiling next kernel - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c_device_buf.SetZero(); - std::string gemm_name = gemm_ptr->GetTypeString(); + std::string op_name = op_ptr->GetTypeString(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); @@ -483,11 +174,11 @@ void profile_gemm_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << gemm_name << std::endl; + << gb_per_sec << " GB/s, " << op_name << std::endl; if(tflops > best_tflops) { - best_gemm_name = gemm_name; + best_op_name = op_name; best_tflops = tflops; best_ave_time = ave_time; best_gb_per_sec = gb_per_sec; @@ -497,86 +188,15 @@ void profile_gemm_impl(int 
do_verification, { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - Tensor a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result( - f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_f32_result( - f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - bf16_to_f32_(a_m_k, a_f32_m_k); - bf16_to_f32_(b_k_n, b_f32_k_n); - bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result); - - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k, - b_f32_k_n, - c_m_n_host_result, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - - ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); - - if(do_log) - { - LogRangeAsType( - std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - } - } - else - { - Tensor c_m_n_host_result( - f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); - - if(do_log) - { - LogRangeAsType( - std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - } - } + pass = + pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); if(do_log) { LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") << std::endl; } @@ -584,8 +204,7 @@ void profile_gemm_impl(int do_verification, } else { - std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem" - << std::endl; + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; } } @@ -627,7 +246,9 @@ void profile_gemm_impl(int do_verification, std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " - << best_gemm_name << std::endl; + << best_op_name << std::endl; + + return pass ? 0 : 1; } } // namespace profiler diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp index dbdc9fd9d8b54e40203f79bb3a736ab76a4bc4d1..0f891a7aeeb8676dbf2a9d7c44df73ea871c6b96 100644 --- a/profiler/include/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -1,37 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
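// Condensed restatement of the instance-lookup flow that replaces the per-type/per-layout
// if-constexpr blocks removed above (template arguments elided; names as in the new code):
//
//   using DeviceOp = ck::tensor_operation::device::DeviceGemm</* layouts, data types, element ops */>;
//   const auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//   for(auto& op_ptr : op_ptrs)
//   {
//       auto argument_ptr = op_ptr->MakeArgumentPointer(/* buffers, sizes, strides, element ops */);
//       if(op_ptr->IsSupportedArgument(argument_ptr.get()))
//           /* run via op_ptr->MakeInvokerPointer(), time it, and keep the best instance */;
//   }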
+ #pragma once -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "device_gemm_reduce.hpp" -#include "reference_gemm.hpp" + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_instance { - -using F32 = float; -using F16 = ck::half_t; -using DPtrsGlobal = ck::Tuple; -using Div = ck::tensor_operation::element_wise::UnaryIdentic; -using Identity = ck::tensor_operation::element_wise::UnaryIdentic; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; - -using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< - DPtrsGlobal, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - DInElementOps, - DOutElementOps>; +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( std::vector&); @@ -45,7 +45,7 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( std::vector&); -} // namespace device_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -56,7 +56,7 @@ namespace profiler { template @@ -91,22 +91,22 @@ bool profile_gemm_reduce_impl(int do_verification, Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d0_m_host_result( + Tensor reduce0_m_host_result( HostTensorDescriptor(std::vector({static_cast(M)}))); - Tensor d1_m_host_result( + Tensor reduce1_m_host_result( HostTensorDescriptor(std::vector({static_cast(M)}))); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d0_m_device_result( + Tensor reduce0_m_device_result( HostTensorDescriptor(std::vector({static_cast(M)}))); - Tensor d1_m_device_result( + Tensor reduce1_m_device_result( HostTensorDescriptor(std::vector({static_cast(M)}))); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << 
c_m_n_host_result.mDesc << std::endl; - std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; - std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; + std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl; + std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl; std::size_t num_thread = 1; switch(init_method) @@ -123,38 +123,41 @@ bool profile_gemm_reduce_impl(int do_verification, b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - using D0ReduceOp = ck::reduce::Add; - using D1ReduceOp = ck::reduce::Add; - using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic; - using UnaryIdenticElementOp = - ck::tensor_operation::element_wise::UnaryIdentic; - using UnarySquareElementOp = - ck::tensor_operation::element_wise::UnarySquare; - using DxsInElementOps = ck::Tuple; - using DxsOutElementOps = ck::Tuple; - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - const auto d0_reduce_op = D0ReduceOp{}; - const auto d1_reduce_op = D1ReduceOp{}; - - auto dxs_in_element_op = DxsInElementOps{}; - auto dxs_out_element_op = DxsOutElementOps{M, M}; + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; if(do_verification) { using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + using ReduceAccDataType = ReduceDataType; + auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -165,43 +168,43 @@ bool profile_gemm_reduce_impl(int do_verification, for(int m = 0; m < M; ++m) { - float d0_acc = d0_reduce_op.GetIdentityValue(); - float d1_acc = d1_reduce_op.GetIdentityValue(); + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); for(int n = 0; n < N; ++n) { - float c_val = ck::type_convert(c_m_n_host_result(m, n)); - float d0_val = 0; - float d1_val = 0; - - dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); - dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); - d0_reduce_op(d0_acc, d0_val); - d1_reduce_op(d1_acc, d1_val); + ReduceAccDataType d0_val = + ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d1_val; + + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); } - dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); - 
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); - d0_m_host_result(m) = ck::type_convert(d0_acc); - d1_m_host_result(m) = ck::type_convert(d1_acc); + div(reduce0_acc, reduce0_acc); + div(reduce1_acc, reduce1_acc); + reduce0_m_host_result(m) = ck::type_convert(reduce0_acc); + reduce1_m_host_result(m) = ck::type_convert(reduce1_acc); } } DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); - DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + reduce0_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + reduce1_m_device_result.mDesc.GetElementSpace()); - auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), - static_cast(d1_device_buf.GetDeviceBuffer())); + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); // add device GEMM instances - std::vector - gemm_ptrs; + std::vector gemm_ptrs; if constexpr(is_same::value && is_same::value && is_same::value) @@ -210,7 +213,7 @@ bool profile_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( gemm_ptrs); } @@ -218,7 +221,7 @@ bool profile_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( gemm_ptrs); } @@ -226,7 +229,7 @@ bool profile_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( gemm_ptrs); } @@ -234,7 +237,7 @@ bool profile_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( gemm_ptrs); } @@ -253,30 +256,31 @@ bool profile_gemm_reduce_impl(int do_verification, // profile device GEMM instances for(auto& gemm_ptr : gemm_ptrs) { - auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - dxs_global, - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - dxs_in_element_op, - dxs_out_element_op); + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops); auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { // init DO, D1 to 0 
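// Scalar sketch of the host-side reference reduction shown above: for each row m of the GEMM
// result, reduce0 is the mean of C(m, :) and reduce1 is the mean of its squares, assuming
// UnaryDivide{N} divides the Add-accumulated sums by N (as its use after ck::reduce::Add suggests).

#include <cstddef>

void reference_row_reduce(const float* c_row, std::size_t N, float& reduce0, float& reduce1)
{
    float sum = 0.f, sum_sq = 0.f;
    for(std::size_t n = 0; n < N; ++n)
    {
        sum    += c_row[n];            // ReduceOp0 = Add applied to PassThrough(c)
        sum_sq += c_row[n] * c_row[n]; // ReduceOp1 = Add applied to UnarySquare(c)
    }
    reduce0 = sum / N;    // UnaryDivide{N}
    reduce1 = sum_sq / N; // UnaryDivide{N}
}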
- d0_device_buf.SetZero(); - d1_device_buf.SetZero(); + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); @@ -306,16 +310,12 @@ bool profile_gemm_reduce_impl(int do_verification, if(do_verification) { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - d0_device_buf.FromDevice(d0_m_device_result.mData.data()); - d1_device_buf.FromDevice(d1_m_device_result.mData.data()); - - float c_error = check_error(c_m_n_host_result, c_m_n_device_result); - float d0_error = check_error(d0_m_host_result, d0_m_device_result); - float d1_error = check_error(d1_m_host_result, d1_m_device_result); + reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); - pass = pass && (c_error < 1E-6); - pass = pass && (d0_error < 1E-6); - pass = pass && (d1_error < 1E-6); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(reduce0_m_device_result.mData, reduce0_m_host_result.mData); + ck::utils::check_err(reduce1_m_device_result.mData, reduce1_m_host_result.mData); if(do_log) { @@ -325,13 +325,17 @@ bool profile_gemm_reduce_impl(int do_verification, << std::endl; LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") << std::endl; - LogRangeAsType(std::cout << "d0_host: ", d0_m_host_result.mData, ",") + LogRangeAsType( + std::cout << "d0_host: ", reduce0_m_host_result.mData, ",") << std::endl; - LogRangeAsType(std::cout << "d0_device: ", d0_m_device_result.mData, ",") + LogRangeAsType( + std::cout << "d0_device: ", reduce0_m_device_result.mData, ",") << std::endl; - LogRangeAsType(std::cout << "d1_host: ", d1_m_host_result.mData, ",") + LogRangeAsType( + std::cout << "d1_host: ", reduce1_m_host_result.mData, ",") << std::endl; - LogRangeAsType(std::cout << "d1_device: ", d1_m_device_result.mData, ",") + LogRangeAsType( + std::cout << "d1_device: ", reduce1_m_device_result.mData, ",") << std::endl; } } diff --git a/profiler/include/profile_gemm_splitk_impl.hpp b/profiler/include/profile_gemm_splitk_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8be879dcbe87472693065b34115ebd87200793f7 --- /dev/null +++ b/profiler/include/profile_gemm_splitk_impl.hpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
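// Conceptual sketch of the split-K idea behind the KBatch parameter taken by the new profiler
// below: the K loop is partitioned into KBatch slices whose partial products are accumulated
// into C. This is only an illustrative host-side restatement, not the device kernel's algorithm.

#include <algorithm>
#include <cstddef>
#include <vector>

void splitk_gemm_rowmajor(const std::vector<float>& A, // M x K, row-major
                          const std::vector<float>& B, // K x N, row-major
                          std::vector<float>& C,       // M x N, row-major, pre-zeroed
                          std::size_t M, std::size_t N, std::size_t K, std::size_t KBatch)
{
    const std::size_t k_per_batch = (K + KBatch - 1) / KBatch;
    for(std::size_t kb = 0; kb < KBatch; ++kb) // each slice's partial GEMM could run independently
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
                for(std::size_t k = kb * k_per_batch; k < std::min(K, (kb + 1) * k_per_batch); ++k)
                    C[m * N + n] += A[m * K + k] * B[k * N + n]; // partial results accumulate
}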
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_splitk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device 
GEMM instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = + pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_grouped_gemm_impl.hpp b/profiler/include/profile_grouped_gemm_impl.hpp index 8806e8ff438bf16eec65362a8268a3aa66e69bc8..6a92b3824cb0b8b2c1c50a949f5a7392f36f3ff7 100644 --- a/profiler/include/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profile_grouped_gemm_impl.hpp @@ -1,22 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
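// Restating the performance-metric arithmetic used by the profiling loops above, where
// ave_time is reported in milliseconds: flop / 1e9 / ms yields TFLOP/s and bytes / 1e6 / ms
// yields GB/s (1e9 FLOP per ms == 1e12 FLOP/s; 1e6 B per ms == 1e9 B/s).

#include <cstddef>

float tflops_from_ms(std::size_t flop, float ave_time_ms)
{
    return static_cast<float>(flop) / 1.E9f / ave_time_ms;
}

float gb_per_sec_from_ms(std::size_t num_bytes, float ave_time_ms)
{
    return static_cast<float>(num_bytes) / 1.E6f / ave_time_ms;
}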
+ #pragma once + #include -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_gemm.hpp" -#include "reference_gemm.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_grouped_gemm_instance { +namespace instance { using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr< ck::tensor_operation::element_wise::PassThrough, @@ -32,7 +36,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( std::vector&); -} // namespace device_grouped_gemm_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -167,9 +171,7 @@ void profile_grouped_gemm_impl(int do_verification, } // add device GEMM instances - std::vector< - ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr> - gemm_ptrs; + std::vector gemm_ptrs; if constexpr(is_same::value && is_same::value && is_same::value) @@ -178,28 +180,28 @@ void profile_grouped_gemm_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_grouped_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { - ck::tensor_operation::device::device_grouped_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { - ck::tensor_operation::device::device_grouped_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { - ck::tensor_operation::device::device_grouped_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); } } @@ -228,6 +230,10 @@ void profile_grouped_gemm_impl(int do_verification, auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { std::string gemm_name = gemm_ptr->GetTypeString(); diff --git a/profiler/include/profile_normalization_impl.hpp b/profiler/include/profile_normalization_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e864698c15a167f113ae3fe91ef9f1edb7efc46 --- /dev/null +++ b/profiler/include/profile_normalization_impl.hpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: 
MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank3_instances(std::vector&); +void add_device_softmax_f16_f16_rank4_instances(std::vector&); + +void add_device_softmax_f32_f32_rank3_instances(std::vector&); +void add_device_softmax_f32_f32_rank4_instances(std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +enum struct NormType +{ + LAYERNORM, + BATCHNORM, + SOFTMAX, +}; + +enum struct NormDataType +{ + F32_F32, // in, out + F16_F16, + BF16_BF16, + INT8_INT8, +}; + +// clang-format off +template std::string type_to_string(); +template <> std::string type_to_string() { return "f32"; } +template <> std::string type_to_string() { return "f16"; } +template <> std::string type_to_string() { return "bf16"; } +template <> std::string type_to_string() { return "int8"; } +template <> std::string type_to_string() { return "int32"; } +// clang-format on + +template +void profile_normalization_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector in_length, + std::vector in_strides, + std::vector reduce_dims, + AccDataType alpha, + AccDataType beta, + NormType norm_type) +{ + Tensor in = in_strides.empty() ? 
Tensor(in_length) + : Tensor(in_length, in_strides); + Tensor out(in.mDesc); + + switch(init_method) + { + // case 0: break; + case 0: + in.GenerateTensorValue(GeneratorTensor_1{}); + out.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + out.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor out_ref(out); + + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + in_dev.ToDevice(in.mData.data()); + out_dev.ToDevice(out.mData.data()); + + std::vector i_in_lengths(in.mDesc.GetLengths().begin(), in.mDesc.GetLengths().end()); + std::vector i_in_strides(in.mDesc.GetStrides().begin(), in.mDesc.GetStrides().end()); + + // add device normalization instances + std::vector instances; + + if(norm_type == NormType::SOFTMAX) + { + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if(in_length.size() == 3) + tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances( + instances); + + if(in_length.size() == 4) + tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances( + instances); + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if(in_length.size() == 3) + tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances( + instances); + + if(in_length.size() == 4) + tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances( + instances); + } + } + + if(instances.size() <= 0) + { + throw std::runtime_error("wrong! no device normalization instance found"); + } + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + for(auto& inst_ptr : instances) + { + // Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3 + // problem to rank 4 kernel) other than invoking IsSupportedArgument()? + if(!(inst_ptr->GetRank() == static_cast(i_in_lengths.size()) && + inst_ptr->GetNumReduceDim() == static_cast(reduce_dims.size()))) + { + continue; + } + + auto argument_ptr = inst_ptr->MakeArgumentPointer(i_in_lengths, + i_in_strides, + reduce_dims, + &alpha, + &beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer()); + + if(!inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = [", in_length, ", ") + << "], " + << "scaler = [" << alpha << ", " << beta << "]." << std::endl; + return; + } + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = + in.mDesc.GetElementSize() * sizeof(InDataType) + + (beta == 0.0f ? 
1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + // TODO: factory method to dynamically switch between different reference normalizations + using ReferenceFactory = + tensor_operation::host::ReferenceSoftmax; + + ReferenceFactory{}.MakeInvoker().Run({in, out_ref, alpha, beta, reduce_dims}); + + out_dev.FromDevice(out.mData.data()); + + bool pass; + if(std::is_same::value) + { + pass = ck::utils::check_err( + out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1); + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_ref : ", out_ref.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; + } + } + else + { + pass = ck::utils::check_err(out.mData, out_ref.mData); + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_ref : ", out_ref.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; + } + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "input lengths = [", in_length, ", ") + << "], " + << "scaler = [" << alpha << ", " << beta << "]." << std::endl; + } + } + } + std::cout << "Best Perf for datatype = " << type_to_string() << "_" + << type_to_string() << ", "; + LogRange(std::cout << "length = ", i_in_lengths, ",") << ", "; + LogRange(std::cout << "stride = ", i_in_strides, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dims, ",") << ", "; + std::cout << "alpha = " << alpha << ", " + << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_instance_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_reduce_impl.hpp b/profiler/include/profile_reduce_impl.hpp index fd519d10333ba27c5d85ee5d0138414c48b4e406..a88b4bcd075b8051a796dbb5c22f88c41b04e6e8 100644 --- a/profiler/include/profile_reduce_impl.hpp +++ b/profiler/include/profile_reduce_impl.hpp @@ -1,17 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
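Note on the profile_normalization_impl.hpp hunk above: the reported bandwidth charges the output only once when beta is zero, since y = alpha * softmax(x) + beta * y never reads the prior output in that case. A minimal standalone sketch of that accounting (softmax_gb_per_sec is an illustrative name and short merely stands in for a 2-byte fp16 element; neither is part of the patch):

    #include <cstddef>
    #include <cstdio>

    template <typename InDataType, typename OutDataType>
    double softmax_gb_per_sec(std::size_t in_elems, std::size_t out_elems, float beta, float avg_time_ms)
    {
        // When beta == 0 the prior output is never read, so it is written once;
        // otherwise it is both read and written, hence the (beta == 0 ? 1 : 2) factor.
        const std::size_t num_bytes =
            in_elems * sizeof(InDataType) +
            (beta == 0.0f ? 1u : 2u) * out_elems * sizeof(OutDataType);

        // bytes / 1e6 / milliseconds == gigabytes per second
        return static_cast<double>(num_bytes) / 1.0e6 / avg_time_ms;
    }

    int main()
    {
        // 16M fp16 elements in and out, beta = 0, at a made-up 0.25 ms per launch
        std::printf("%.1f GB/s\n", softmax_gb_per_sec<short, short>(16u << 20, 16u << 20, 0.0f, 0.25f));
        return 0;
    }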
+ #pragma once -#include "check_err.hpp" -#include "device_reduce.hpp" -#include "device_reduce_instance.hpp" -#include "reduction_enums.hpp" -#include "host_reduction.hpp" -#include "host_common_util.hpp" -#include "host_tensor_generator.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_reduction.hpp" +#include "ck/library/host_tensor/host_common_util.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_reduce_instance { +namespace instance { template struct ReduceDescription @@ -86,7 +91,7 @@ bool description_match(const DescriptionType& description, return (result); }; -} // namespace device_reduce_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -137,7 +142,7 @@ bool profile_reduce_impl_impl(bool do_verification, float beta) { using namespace ck::tensor_operation::device; - using namespace ck::tensor_operation::device::device_reduce_instance; + using namespace ck::tensor_operation::device::instance; using ck::host_common::dumpBufferToFile; constexpr bool op_support_indices = @@ -261,13 +266,18 @@ bool profile_reduce_impl_impl(bool do_verification, float best_gb_per_sec = 0; using InElementwiseOperation = - typename reduce_unary_operator:: - InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; + + using ReduceOperation = typename reduce_binary_operator::opType; - using ReduceOperation = typename reduce_binary_operator::opType; + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); using DeviceReduceInstPtr0 = DeviceReducePtr; @@ -323,8 +333,13 @@ bool profile_reduce_impl_impl(bool do_verification, OutputIndex> hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); - hostReduce.Run( - alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); + hostReduce.Run(alpha, + in.mData.data(), + beta, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + acc_elementwise_op); }; std::vector i_inLengths; @@ -339,10 +354,6 @@ bool profile_reduce_impl_impl(bool do_verification, for(auto& reduce_ptr : reduce0_ptrs) { - - InElementwiseOperation in_elementwise_op(static_cast(reduce_total_length)); - AccElementwiseOperation acc_elementwise_op(static_cast(reduce_total_length)); - auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, i_inStrides, i_outLengths, @@ -453,7 +464,7 @@ bool profile_reduce_impl(bool do_verification, bool pass = true; using tuple_of_description_instances = - tensor_operation::device::device_reduce_instance::reduce_description_instances; + tensor_operation::device::instance::reduce_description_instances; const auto tuple_object = tuple_of_description_instances{}; diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp index 
fbdc07c3da1f71473d1d9ae6fec0cb3c87686662..7c4e2f7b7d8ab921c73dc1b6aaa1bb56f4f475d6 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -1,20 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include #include -#include -#include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" -#include "device_batched_gemm_xdl.hpp" -#include "profile_batched_gemm_impl.hpp" + +#include "profiler/include/profile_batched_gemm_impl.hpp" enum struct GemmMatrixLayout { @@ -22,10 +15,6 @@ enum struct GemmMatrixLayout MK_NK_MN, // 1 KM_KN_MN, // 2 KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 }; enum struct GemmDataType @@ -38,8 +27,9 @@ enum struct GemmDataType int profile_batched_gemm(int argc, char* argv[]) { - if(!(argc == 15)) + if(argc != 18) { + // clang-format off printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n"); printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n"); printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n"); @@ -50,7 +40,8 @@ int profile_batched_gemm(int argc, char* argv[]) printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n"); printf("arg7: time kernel (0=n0, 1=yes)\n"); - printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); + printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); + // clang-format on exit(1); } @@ -69,332 +60,138 @@ int profile_batched_gemm(int argc, char* argv[]) const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); - const int BatchCount = std::stoi(argv[14]); + const int BatchStrideA = std::stoi(argv[14]); + const int BatchStrideB = std::stoi(argv[15]); + const int BatchStrideC = std::stoi(argv[16]); + + const int BatchCount = std::stoi(argv[17]); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; + const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; + const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; - if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + const int DefaultBatchStrideA = (ck::is_same_v ? M : K) * StrideA_; + const int DefaultBatchStrideB = (ck::is_same_v ? K : N) * StrideB_; + const int DefaultBatchStrideC = (ck::is_same_v ? M : N) * StrideC_; + + const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA; + const int BatchStrideB_ = (BatchStrideB < 0) ? 
DefaultBatchStrideB : BatchStrideB; + const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC; + + bool pass = ck::profiler:: + profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + BatchStrideA_, + BatchStrideB_, + BatchStrideC_, + StrideA_, + StrideB_, + StrideC_, + BatchCount); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{}); } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? 
N : StrideC, - BatchCount); + return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{}); } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{}); } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(BF16{}, BF16{}, BF16{}, Row{}, Row{}, Row{}); } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(BF16{}, BF16{}, BF16{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(BF16{}, BF16{}, BF16{}, Col{}, Row{}, Row{}); } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(BF16{}, BF16{}, BF16{}, Col{}, Col{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? 
N : StrideC, - BatchCount); + return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{}); } else { - throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); - } + std::cout << "this data_type & layout is not implemented" << std::endl; - return 0; + return 1; + } } diff --git a/profiler/src/profile_batched_gemm_reduce.cpp b/profiler/src/profile_batched_gemm_reduce.cpp index 594fc6bedb6ff95e0447c05f9275f8ada458fcbf..7c518e979bb5a8b29ede136bb3bfdf0144da0cf8 100644 --- a/profiler/src/profile_batched_gemm_reduce.cpp +++ b/profiler/src/profile_batched_gemm_reduce.cpp @@ -1,11 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include -#include "profile_batched_gemm_reduce_impl.hpp" +#include "profiler/include/profile_batched_gemm_reduce_impl.hpp" int profile_batched_gemm_reduce(int argc, char* argv[]) { diff --git a/profiler/src/profile_conv_bwd_weight.cpp b/profiler/src/profile_conv_bwd_weight.cpp index 80413322b301d34b208054d7206558dfbe029aec..989c480886b5260508b29cfe95355b2ba5050960 100644 --- a/profiler/src/profile_conv_bwd_weight.cpp +++ b/profiler/src/profile_conv_bwd_weight.cpp @@ -1,10 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include -#include "profile_conv_bwd_weight_impl.hpp" + +#include "profiler/include/profile_conv_bwd_weight_impl.hpp" enum struct ConvDataType { diff --git a/profiler/src/profile_conv_fwd_bias_relu.cpp b/profiler/src/profile_conv_fwd_bias_relu.cpp index ca7dc1935ae1bd5980738efd21f7fc706656da14..91f4836a2bcaffd918fc3b202d5945783229e021 100644 --- a/profiler/src/profile_conv_fwd_bias_relu.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu.cpp @@ -1,10 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include -#include "profile_conv_fwd_bias_relu_impl.hpp" + +#include "profiler/include/profile_conv_fwd_bias_relu_impl.hpp" enum struct ConvDataType { diff --git a/profiler/src/profile_conv_fwd_bias_relu_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_add.cpp index 5d75f5a2943dd2cb6bd1fada223b7c3db9b49930..5cc6faba346e4eb08ec417582b6bfac46f0c48ec 100644 --- a/profiler/src/profile_conv_fwd_bias_relu_add.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu_add.cpp @@ -1,10 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
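Note on the profile_batched_gemm.cpp rewrite above: negative stride arguments now fall back to packed defaults, and the three new BatchStride arguments default to one full matrix per batch. A self-contained sketch of that rule for the A matrix (BatchedStrides and default_a_strides are illustrative names, not part of the patch):

    #include <cstdio>

    struct BatchedStrides
    {
        int stride;       // leading-dimension stride within one matrix
        int batch_stride; // elements between consecutive matrices in the batch
    };

    // Row-major M x K: the leading stride defaults to K and one batch spans M * StrideA
    // elements; column-major is the transpose, matching DefaultBatchStrideA above.
    BatchedStrides default_a_strides(int M, int K, bool a_row_major)
    {
        BatchedStrides s{};
        s.stride       = a_row_major ? K : M;
        s.batch_stride = (a_row_major ? M : K) * s.stride;
        return s;
    }

    int main()
    {
        const auto s = default_a_strides(/*M=*/128, /*K=*/64, /*a_row_major=*/true);
        std::printf("StrideA=%d BatchStrideA=%d\n", s.stride, s.batch_stride); // 64, 8192
        return 0;
    }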
+ #include #include #include #include -#include -#include -#include "profile_conv_fwd_bias_relu_add_impl.hpp" + +#include "profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp" enum struct ConvDataType { diff --git a/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp deleted file mode 100644 index 96d3b10ddfab5469cf47e7777a2b7ca4ed3cdda1..0000000000000000000000000000000000000000 --- a/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp +++ /dev/null @@ -1,116 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp" - -enum struct ConvDataType -{ - F32_F32_F32, // 0 - F16_F16_F16, // 1 -}; - -enum struct ConvInputLayout -{ - NCHW, // 0 - NHWC, // 1 -}; - -enum struct ConvWeightLayout -{ - KCYX, // 0 - KYXC, // 1 -}; - -enum struct ConvOutputLayout -{ - NKHW, // 0 - NHWK, // 1 -}; - -int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) -{ - if(argc != 25) - { - printf("arg1: tensor operation (conv_fwd_bias_relu_atomic_add: " - "ForwardConvolution+Bias+ReLu+AtomicAdd)\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); - printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); - printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); - printf("arg6: verification (0: no; 1: yes)\n"); - printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg9: time kernel (0=n0, 1=yes)\n"); - printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(1); - } - - const auto data_type = static_cast(std::stoi(argv[2])); - const auto in_layout = static_cast(std::stoi(argv[3])); - const auto wei_layout = static_cast(std::stoi(argv[4])); - const auto out_layout = static_cast(std::stoi(argv[5])); - const bool do_verification = std::stoi(argv[6]); - const int init_method = std::stoi(argv[7]); - const bool do_log = std::stoi(argv[8]); - const bool time_kernel = std::stoi(argv[9]); - - const ck::index_t N = std::stoi(argv[10]); - const ck::index_t K = std::stoi(argv[11]); - const ck::index_t C = std::stoi(argv[12]); - const ck::index_t Y = std::stoi(argv[13]); - const ck::index_t X = std::stoi(argv[14]); - const ck::index_t Hi = std::stoi(argv[15]); - const ck::index_t Wi = std::stoi(argv[16]); - - const ck::index_t conv_stride_h = std::stoi(argv[17]); - const ck::index_t conv_stride_w = std::stoi(argv[18]); - const ck::index_t conv_dilation_h = std::stoi(argv[19]); - const ck::index_t conv_dilation_w = std::stoi(argv[20]); - const ck::index_t in_left_pad_h = std::stoi(argv[21]); - const ck::index_t in_left_pad_w = std::stoi(argv[22]); - const ck::index_t in_right_pad_h = std::stoi(argv[23]); - const ck::index_t in_right_pad_w = std::stoi(argv[24]); - - const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; - const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; - - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - - if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && - wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) - { - ck::profiler::profile_conv_fwd_bias_relu_atomic_add_impl< - 2, - ck::half_t, - ck::half_t, - ck::half_t, - ck::tensor_layout::convolution::NHWC, - 
ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( - do_verification, - init_method, - do_log, - time_kernel, - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - std::vector{conv_stride_h, conv_stride_w}, - std::vector{conv_dilation_h, conv_dilation_w}, - std::vector{in_left_pad_h, in_left_pad_w}, - std::vector{in_right_pad_h, in_right_pad_w}); - } - else - { - throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); - } - - return 0; -} diff --git a/profiler/src/profile_convnd_bwd_data.cpp b/profiler/src/profile_convnd_bwd_data.cpp index 5d0e6a34c7bdc70a26399d3f6637d45b153c941a..7c387d375e6ec7cca5dd7d7de4a79e58ff460cd4 100644 --- a/profiler/src/profile_convnd_bwd_data.cpp +++ b/profiler/src/profile_convnd_bwd_data.cpp @@ -1,11 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include -#include "profile_convnd_bwd_data_impl.hpp" +#include "profiler/include/profile_convnd_bwd_data_impl.hpp" namespace { diff --git a/profiler/src/profile_convnd_bwd_weight.cpp b/profiler/src/profile_convnd_bwd_weight.cpp new file mode 100644 index 0000000000000000000000000000000000000000..741d9ac656fb4ff98a02de6eb64762fcf555103a --- /dev/null +++ b/profiler/src/profile_convnd_bwd_weight.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/include/profile_convnd_bwd_weight_impl.hpp" + +namespace { + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], int arg_idx) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // namespace + +int profile_convnd_bwd_weight(int argc, char* argv[], int num_dim_spatial) +{ + const int preParams = 11; + int conv_args = 3 + 
num_dim_spatial * 6; + int cmdline_nargs = conv_args + preParams; + if(cmdline_nargs != argc) + { + printf("arg1: tensor operation (convnd[1|2|3]d_bwd_weight: BackwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16, 2: bf16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10: splitk\n"); + printf("arg11 to 25: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + ck::index_t split_k = std::stoi(argv[10]); + split_k = std::max(1, split_k); + + ck::utils::conv::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams); + + auto Run = [&](auto input_type, auto wei_type, auto out_type) { + using InDataType = decltype(input_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + switch(num_dim_spatial) + { + case 1: + ck::profiler::profile_convnd_bwd_weight_impl<1, + InDataType, + WeiDataType, + OutDataType, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + split_k); + break; + + case 2: + ck::profiler::profile_convnd_bwd_weight_impl<2, + InDataType, + WeiDataType, + OutDataType, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + split_k); + break; + + case 3: + ck::profiler::profile_convnd_bwd_weight_impl<3, + InDataType, + WeiDataType, + OutDataType, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + split_k); + break; + + default: break; + } + }; + if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(float{}, 
float{}, float{}); + } + else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(ck::half_t{}, ck::half_t{}, ck::half_t{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}); + } + else + { + std::cout << "wrong! this Conv data_type & layout is not implemented" << std::endl; + return 1; + } + + return 0; +} diff --git a/profiler/src/profile_convnd_fwd.cpp b/profiler/src/profile_convnd_fwd.cpp index 87778a04a537f4afc607bbea86f997b44f44d2d4..8223be160ed21161a9c45a59335868d308ffb327 100644 --- a/profiler/src/profile_convnd_fwd.cpp +++ b/profiler/src/profile_convnd_fwd.cpp @@ -1,15 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include +#include #include #include #include #include -#include -#include "conv_util.hpp" -#include "element_wise_operation.hpp" -#include "fill.hpp" -#include "profile_convnd_fwd.hpp" -#include "tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/fill.hpp" namespace { @@ -150,9 +153,12 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::utils::FillUniform, - ck::utils::FillUniform>>( - params, true, ck::utils::FillUniform{}, ck::utils::FillUniform{}); + ck::utils::FillUniformDistributionIntegerValue, + ck::utils::FillUniformDistributionIntegerValue>>( + params, + true, + ck::utils::FillUniformDistributionIntegerValue{}, + ck::utils::FillUniformDistributionIntegerValue{}); break; case 2: conv_instance = std::make_unique< @@ -165,12 +171,12 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::utils::FillUniform, - ck::utils::FillUniform>>( + ck::utils::FillUniformDistribution, + ck::utils::FillUniformDistribution>>( params, true, - ck::utils::FillUniform{}, - ck::utils::FillUniform{}); + ck::utils::FillUniformDistribution{}, + ck::utils::FillUniformDistribution{}); break; default: throw std::runtime_error("Unsupported init method!"); } @@ -181,8 +187,10 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, _1, _2, _3); - OpInstanceRunEngine run_engine(*conv_instance, - reference_conv_fwd_fun); + + OpInstanceRunEngine run_engine( + *conv_instance, reference_conv_fwd_fun, do_verification); + auto best_conf = run_engine.Profile( conv::ConvolutionFwdInstances::template Get(), time_kernel, @@ -295,7 +303,7 @@ void profile_convnd_instances(ConvDataType data_type, } // namespace -int ck::profiler::profile_convnd_fwd(int argc, char* argv[]) +int profile_convnd_fwd(int argc, char* argv[]) { using namespace ck::utils::conv; diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index 0684e183221a631121c25b88f19d21fca74a3906..624f3dbf6117fa0a5760fd7c35d6365547c18cf6 100644 --- a/profiler/src/profile_gemm.cpp +++ 
b/profiler/src/profile_gemm.cpp @@ -1,10 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include -#include "profile_gemm_impl.hpp" + +#include "profiler/include/profile_gemm_impl.hpp" enum struct GemmMatrixLayout { @@ -12,10 +14,6 @@ enum struct GemmMatrixLayout MK_NK_MN, // 1 KM_KN_MN, // 2 KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 }; enum struct GemmDataType @@ -28,7 +26,7 @@ enum struct GemmDataType int profile_gemm(int argc, char* argv[]) { - if(!(argc == 14 || argc == 15)) + if(argc != 14) { printf("arg1: tensor operation (gemm: GEMM)\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); @@ -39,9 +37,8 @@ int profile_gemm(int argc, char* argv[]) printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n"); - printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); - printf("arg14: split k into mulitiple batch\n"); exit(1); } @@ -59,350 +56,125 @@ int profile_gemm(int argc, char* argv[]) const int StrideA = std::stoi(argv[11]); const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); - int KBatch = 1; - if(argc == 15) - KBatch = std::stoi(argv[14]); - if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + using INT32 = int32_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = + ck::profiler::profile_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? 
N : StrideC, - KBatch); + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? 
N : StrideC, - KBatch); + return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Col{}, Row{}); } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Row{}, Row{}); } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Row{}, Row{}); } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Col{}, Row{}); } else { - throw std::runtime_error("wrong! 
this GEMM data_type & layout is not implemented"); - } + std::cout << "this data_type & layout is not implemented" << std::endl; - return 0; + return 1; + } } diff --git a/profiler/src/profile_gemm_add_add_fastgelu.cpp b/profiler/src/profile_gemm_add_add_fastgelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a381222cbc5e4ace3766b285438e3ad144588a0b --- /dev/null +++ b/profiler/src/profile_gemm_add_add_fastgelu.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/include/profile_gemm_add_add_fastgelu_impl.hpp" + +int profile_gemm_add_add_fastgelu(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN_MN, // 0 + MK_NK_MN_MN_MN, // 1 + KM_KN_MN_MN_MN, // 2 + KM_NK_MN_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16, // 1 + BF16_BF16_BF16_BF16_BF16, // 2 + INT8_INT8_INT8_INT8_INT8, // 3 + }; + + if(argc != 16) + { + // clang-format off + printf("arg1: tensor operation (gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);\n"); + printf(" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);\n"); + printf(" 2: E[m, n] = FastGeLU(A[k, m] * B[k, n] + D0[m, n] + D1[m, n]);\n"); + printf(" 3: E[m, n] = FastGeLU(A[k, m] * B[n, k] + D0[m, n] + D1[m, n]))\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideD1 = std::stoi(argv[14]); + const int StrideE = std::stoi(argv[15]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto d1_type, + auto e_type, + auto a_layout, + auto b_layout, + auto de_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using D1DataType = decltype(d1_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using DELayout = decltype(de_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideD1 = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? 
N : M; + + bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::MK_NK_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::KM_KN_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::KM_NK_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} diff --git a/profiler/src/profile_gemm_bias_2d.cpp b/profiler/src/profile_gemm_bias_2d.cpp deleted file mode 100644 index 51dba85f326cafdd28f0f19b4b92958e8dc423f2..0000000000000000000000000000000000000000 --- a/profiler/src/profile_gemm_bias_2d.cpp +++ /dev/null @@ -1,256 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "profile_gemm_bias_2d_impl.hpp" - -enum struct GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 -}; - -enum struct GemmDataType -{ - F32_F32_F32, // 0 - F16_F16_F16, // 1 -}; - -int profile_gemm_bias_2d(int argc, char* argv[]) -{ - if(!(argc == 16 || argc == 17)) - { - printf("arg1: tensor operation (gemm: GEMM+Bias_2d)\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); - printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); - printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); - printf("arg4: verification (0: no; 1: yes)\n"); - printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg6: print tensor value (0: no; 1: yes)\n"); - printf("arg7: time kernel (0=n0, 1=yes)\n"); - printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); - printf("arg14: alpha\n"); - printf("arg15: beta\n"); - printf("arg16: split k into mulitiple batch\n"); - exit(1); - } - - const auto data_type = static_cast(std::stoi(argv[2])); - const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const bool time_kernel = std::stoi(argv[7]); - - const int M = std::stoi(argv[8]); - const int N = std::stoi(argv[9]); - const int K = std::stoi(argv[10]); - - const int StrideA = std::stoi(argv[11]); - const int StrideB = std::stoi(argv[12]); - const int StrideC = std::stoi(argv[13]); - - const float alpha = std::stof(argv[14]); - const float beta = std::stof(argv[15]); - - if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_gemm_bias_2d_impl( - do_verification, - init_method, - do_log, - 
time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - alpha, - beta); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_gemm_bias_2d_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - alpha, - beta); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_gemm_bias_2d_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - alpha, - beta); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_gemm_bias_2d_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - alpha, - beta); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_gemm_bias_2d_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - alpha, - beta); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_gemm_bias_2d_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - alpha, - beta); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_gemm_bias_2d_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - alpha, - beta); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_gemm_bias_2d_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - alpha, - beta); - } - else - { - throw std::runtime_error("wrong! this data_type & layout is not implemented"); - } - - return 0; -} diff --git a/profiler/src/profile_gemm_bias_add_reduce.cpp b/profiler/src/profile_gemm_bias_add_reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc2675703f6d7f9ddc96b415d6da1e52949da7b1 --- /dev/null +++ b/profiler/src/profile_gemm_bias_add_reduce.cpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
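Note on the profile_gemm.cpp and profile_gemm_add_add_fastgelu.cpp rewrites above: every (data type, layout) combination is now routed through one generic lambda that takes value tags and recovers the types with decltype, so each branch reduces to a single profile(...) call. A minimal illustration of the idiom (RowMajor, ColumnMajor and the printed values are placeholders, not the CK tensor_layout types):

    #include <iostream>
    #include <type_traits>

    struct RowMajor {};
    struct ColumnMajor {};

    int main()
    {
        const int M = 128, K = 64;

        auto profile = [&](auto a_type, auto a_layout) {
            using ADataType = decltype(a_type);
            using ALayout   = decltype(a_layout);

            // Packed default leading dimension, mirroring DefaultStrideA in the drivers above.
            const int stride_a = std::is_same_v<ALayout, RowMajor> ? K : M;

            std::cout << "element size = " << sizeof(ADataType)
                      << ", default StrideA = " << stride_a << '\n';
            return 0;
        };

        return profile(float{}, RowMajor{}); // e.g. fp32 with a row-major A matrix
    }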
+ +#include +#include +#include +#include + +#include "profiler/include/profile_gemm_bias_add_reduce_impl.hpp" + +int profile_gemm_bias_add_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType + { + F32_F32_F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (gemm: GEMM+bias+add+Reduce)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int StrideC1 = std::stoi(argv[14]); + + if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else + { + throw std::runtime_error("wrong! 
this data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_gemm_bias_relu.cpp b/profiler/src/profile_gemm_bias_relu.cpp deleted file mode 100644 index bf035d9ad9a1c17b0f9032d26da74153efedb2b2..0000000000000000000000000000000000000000 --- a/profiler/src/profile_gemm_bias_relu.cpp +++ /dev/null @@ -1,143 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "profile_gemm_bias_relu_impl.hpp" - -enum struct GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 -}; - -enum struct GemmDataType -{ - F32_F32_F32, // 0 - F16_F16_F16, // 1 -}; - -int profile_gemm_bias_relu(int argc, char* argv[]) -{ - if(!(argc == 14 || argc == 15)) - { - printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU)\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); - printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); - printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); - printf("arg4: verification (0: no; 1: yes)\n"); - printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg6: print tensor value (0: no; 1: yes)\n"); - printf("arg7: time kernel (0=n0, 1=yes)\n"); - printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); - printf("arg14: split k into mulitiple batch\n"); - exit(1); - } - - const auto data_type = static_cast(std::stoi(argv[2])); - const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const bool time_kernel = std::stoi(argv[7]); - - const int M = std::stoi(argv[8]); - const int N = std::stoi(argv[9]); - const int K = std::stoi(argv[10]); - - const int StrideA = std::stoi(argv[11]); - const int StrideB = std::stoi(argv[12]); - const int StrideC = std::stoi(argv[13]); - - if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_gemm_bias_relu_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_gemm_bias_relu_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_gemm_bias_relu_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_gemm_bias_relu_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); - } - else - { - throw std::runtime_error("wrong! 
this data_type & layout is not implemented"); - } - - return 0; -} diff --git a/profiler/src/profile_gemm_bias_relu_add.cpp b/profiler/src/profile_gemm_bias_relu_add.cpp deleted file mode 100644 index 9c324f6cf95aa8e7717dc7430583eaf258b90e41..0000000000000000000000000000000000000000 --- a/profiler/src/profile_gemm_bias_relu_add.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "profile_gemm_bias_relu_add_impl.hpp" - -enum struct GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 -}; - -enum struct GemmDataType -{ - F32_F32_F32, // 0 - F16_F16_F16, // 1 -}; - -int profile_gemm_bias_relu_add(int argc, char* argv[]) -{ - if(!(argc == 15 || argc == 16)) - { - printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU+Add)\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); - printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); - printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); - printf("arg4: verification (0: no; 1: yes)\n"); - printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg6: print tensor value (0: no; 1: yes)\n"); - printf("arg7: time kernel (0=n0, 1=yes)\n"); - printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n"); - printf("arg15: split k into mulitiple batch\n"); - exit(1); - } - - const auto data_type = static_cast(std::stoi(argv[2])); - const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const bool time_kernel = std::stoi(argv[7]); - - const int M = std::stoi(argv[8]); - const int N = std::stoi(argv[9]); - const int K = std::stoi(argv[10]); - - const int StrideA = std::stoi(argv[11]); - const int StrideB = std::stoi(argv[12]); - const int StrideC = std::stoi(argv[13]); - const int StrideC1 = std::stoi(argv[14]); - - if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_gemm_bias_relu_add_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - (StrideC1 < 0) ? N : StrideC1); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_gemm_bias_relu_add_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - (StrideC1 < 0) ? N : StrideC1); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_gemm_bias_relu_add_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - (StrideC1 < 0) ? N : StrideC1); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_gemm_bias_relu_add_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - (StrideC1 < 0) ? N : StrideC1); - } - else - { - throw std::runtime_error("wrong! 
this data_type & layout is not implemented"); - } - - return 0; -} diff --git a/profiler/src/profile_gemm_bilinear.cpp b/profiler/src/profile_gemm_bilinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..14c577897c0986c5f349a87649dbad871eabe1ef --- /dev/null +++ b/profiler/src/profile_gemm_bilinear.cpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/include/profile_gemm_bilinear_impl.hpp" + +int profile_gemm_bilinear(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN, // 0 + MK_NK_MN_MN, // 1 + KM_KN_MN_MN, // 2 + KM_NK_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F32_F32_F32_F32, // 0 + F16_F16_F16_F16, // 1 + BF16_BF16_BF16_BF16, // 2 + INT8_INT8_INT8_INT8, // 3 + }; + + if(argc != 17) + { + // clang-format off + printf("arg1: tensor operation (gemm_bilinear: GEMM+Bilinear)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: E[m, n] = alpha * A[m, k] * B[k, n] + beta * D[m, n];\n"); + printf(" 1: E[m, n] = alpha * A[m, k] * B[n, k] + beta * D[m, n];\n"); + printf(" 2: E[m, n] = alpha * A[k, m] * B[k, n] + beta * D[m, n];\n"); + printf(" 3: E[m, n] = alpha * A[k, m] * B[n, k] + beta * D[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD, StrideE\n"); + printf("arg15 to 16: alhpa, beta\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD = std::stoi(argv[13]); + const int StrideE = std::stoi(argv[14]); + + const float alpha = std::stof(argv[15]); + const float beta = std::stof(argv[16]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d_type, + auto e_type, + auto a_layout, + auto b_layout, + auto de_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using DDataType = decltype(d_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using DELayout = decltype(de_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_bilinear_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD < 0) ? DefaultStrideD : StrideD, + (StrideE < 0) ? 
DefaultStrideE : StrideE, + alpha, + beta); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_NK_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::KM_KN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::KM_NK_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} diff --git a/profiler/src/profile_gemm_reduce.cpp b/profiler/src/profile_gemm_reduce.cpp index a23967acd7ac5ac99396b5e372128401916b71a4..476943c8a72db4d62070776c401f850c5d7da195 100644 --- a/profiler/src/profile_gemm_reduce.cpp +++ b/profiler/src/profile_gemm_reduce.cpp @@ -1,10 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include -#include "profile_gemm_reduce_impl.hpp" + +#include "profiler/include/profile_gemm_reduce_impl.hpp" int profile_gemm_reduce(int argc, char* argv[]) { diff --git a/profiler/src/profile_gemm_splitk.cpp b/profiler/src/profile_gemm_splitk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fff023c8e0f6ea1d56377a8b837ca1397dcf6049 --- /dev/null +++ b/profiler/src/profile_gemm_splitk.cpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
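For the gemm_bilinear dispatch just above, a representative command line (the sizes, alpha and beta are hypothetical; ckProfiler is the profiler binary produced by the build) would be:

    ckProfiler gemm_bilinear 1 1 1 1 0 1 1024 1024 1024 -1 -1 -1 -1 1.0 1.0

arg2 = 1 selects fp16 and arg3 = 1 selects the A[m, k] * B[n, k] layout, so this exercises the Row/Col/Row specialization; negative strides fall back to the layout-derived defaults computed in the lambda.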
+ +#include +#include +#include +#include + +#include "profiler/include/profile_gemm_splitk_impl.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +int profile_gemm_splitk(int argc, char* argv[]) +{ + if(argc != 15) + { + printf("arg1: tensor operation (gemm_splitk: Split-K GEMM)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int KBatch = std::stoi(argv[14]); + + using F32 = float; + using F16 = ck::half_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_splitk_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + KBatch); + + return pass ? 
0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index ea73d446e38d07cdd0473f37557be44286ab8cf8..a51505ae9c6f9c548c81eaf81a6606dc99272f46 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -1,10 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include -#include "profile_grouped_gemm_impl.hpp" + +#include "profiler/include/profile_grouped_gemm_impl.hpp" enum struct GemmMatrixLayout { diff --git a/profiler/src/profile_normalization.cpp b/profiler/src/profile_normalization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..277a78a669a4e5cacee05c0b546fd340050567fd --- /dev/null +++ b/profiler/src/profile_normalization.cpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
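A sketch of how the split-K profiler above might be driven, with hypothetical sizes and an arbitrary KBatch of 4:

    ckProfiler gemm_splitk 1 1 1 1 0 1 3840 4096 4096 -1 -1 -1 4

The final argument is the KBatch factor, i.e. the K dimension is split into four partial GEMMs whose results are accumulated; -1 strides again resolve to the layout-derived defaults.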
+ +#include +#include +#include + +#include "profiler/include/profile_normalization_impl.hpp" + +using ck::index_t; +using ck::profiler::NormDataType; +using ck::profiler::NormType; + +struct ArgParser +{ + std::unordered_map norm_dict = {{"layernorm", NormType::LAYERNORM}, + {"batchnorm", NormType::BATCHNORM}, + {"softmax", NormType::SOFTMAX}}; + + std::unordered_map> long_opts = { + {"length", {}}, {"stride", {}}, {"reduce", {}}, {"alpha", {}}, {"beta", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help() +{ + std::cout << "arg1: tensor operation (layernorm/batchnorm/softmax)\n" + << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n" + << "arg3: verification (0: no; 1: yes)\n" + << "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg5: print tensor value (0: no; 1: yes)\n" + << "arg6: time kernel (0=n0, 1=yes)\n" + << "--length: tensor extents (e.g, --length 8 4 256) \n" + << "--stride: tensor strides (e.g, --stride 1024 256 1)\n" + << "--reduce: to-reduce dimensions (e.g, --reduce 2)\n" + << "--alpha: alpha scaling value\n" + << "--beta: beta scaling value\n" + << std::endl; +} + +int profile_normalization(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help(); + return 0; + } + + ArgParser arg_parser; + + // short unnamed options + const NormType norm_type = arg_parser.norm_dict[argv[1]]; + const NormDataType data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + const std::vector stride = arg_parser.long_opts["stride"]; + const std::vector reduce = arg_parser.long_opts["reduce"]; + const index_t alpha = + arg_parser.long_opts["alpha"].empty() ? 1 : arg_parser.long_opts["alpha"][0]; + const index_t beta = arg_parser.long_opts["beta"].empty() ? 
0 : arg_parser.long_opts["beta"][0]; + + if(data_type == NormDataType::F16_F16) + { + ck::profiler::profile_normalization_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + float(alpha), + float(beta), + norm_type); + } + else if(data_type == NormDataType::F32_F32) + { + ck::profiler::profile_normalization_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + float(alpha), + float(beta), + norm_type); + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +// hijack main() for quick debugging +// int main(int argc, char* argv[]) +// { +// profile_normalization(argc, argv); +// return 0; +// } diff --git a/profiler/src/profile_reduce.cpp b/profiler/src/profile_reduce.cpp index bdbac4fab4ff7ae9431c14fa2b4cb00ba8df3f12..d31cdb74d8e81a60928dd4145effc3bb4e6b2204 100644 --- a/profiler/src/profile_reduce.cpp +++ b/profiler/src/profile_reduce.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include @@ -6,11 +9,12 @@ #include #include -#include "data_type_enum.hpp" -#include "reduction_enums.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/host_tensor/host_common_util.hpp" -#include "host_common_util.hpp" -#include "profile_reduce_impl.hpp" +#include "profiler/include/profile_reduce_impl.hpp" +#include "profiler/include/data_type_enum.hpp" using namespace std; diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 6640cf4ce6cf37c0ba6615bde84627e4cc56cbe5..046c6168ff6ac525179ae7156a6cc4fec2d38f10 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -1,49 +1,83 @@ -#include -#include -#include -#include -#include +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
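Unlike the GEMM profilers, profile_normalization above mixes six positional arguments with the long options handled by its small ArgParser. A hypothetical fp16 softmax run over an 8 x 4 x 256 tensor, reducing the innermost dimension, could look like:

    ckProfiler softmax 1 1 1 0 1 --length 8 4 256 --stride 1024 256 1 --reduce 2

--alpha and --beta are omitted here, so they default to 1 and 0 as in the code above.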
-#include "profile_convnd_fwd.hpp" +#include int profile_gemm(int, char*[]); -int profile_gemm_bias_2d(int, char*[]); -int profile_gemm_bias_relu(int, char*[]); -int profile_gemm_bias_relu_add(int, char*[]); +int profile_gemm_splitk(int, char*[]); +int profile_gemm_bilinear(int, char*[]); +int profile_gemm_add_add_fastgelu(int, char*[]); int profile_gemm_reduce(int, char*[]); +int profile_gemm_bias_add_reduce(int, char*[]); int profile_batched_gemm(int, char*[]); +int profile_batched_gemm_reduce(int, char*[]); int profile_grouped_gemm(int, char*[]); int profile_conv_fwd(int, char*[]); int profile_conv_fwd_bias_relu(int, char*[]); int profile_conv_fwd_bias_relu_add(int, char*[]); -int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); +int profile_convnd_fwd(int argc, char* argv[]); int profile_convnd_bwd_data(int, char*[], int); -int profile_reduce(int, char*[]); int profile_conv_bwd_weight(int, char*[]); -int profile_batched_gemm_reduce(int, char*[]); +int profile_normalization(int, char*[]); +int profile_reduce(int, char*[]); +int profile_convnd_bwd_weight(int, char*[], int); + +static void print_helper_message() +{ + // clang-format off + printf("arg1: tensor operation (gemm: GEMM\n" + " gemm_splitk: Split-K GEMM\n" + " gemm_bilinear: GEMM+Bilinear\n" + " gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n" + " gemm_reduce: GEMM+Reduce\n" + " gemm_bias_add_reduce: GEMM+Bias+Add+Reduce\n" + " batched_gemm: Batched GEMM\n" + " batched_gemm_reduce: Batched GEMM+Reduce\n" + " grouped_gemm: Grouped GEMM\n" + " conv_fwd: ForwardConvolution\n" + " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" + " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" + " conv1d_bwd_data: BackwardConvolution data 1 dim\n" + " conv2d_bwd_data: BackwardConvolution data 2 dim\n" + " conv3d_bwd_data: BackwardConvolution data 3 dim\n" + " conv2d_bwd_weight: Backward Weight Convolution 2d\n" + " reduce: Reduce\n"); + // clang-format on +} int main(int argc, char* argv[]) { + if(argc == 1) + { + print_helper_message(); + + return 0; + } + if(strcmp(argv[1], "gemm") == 0) { return profile_gemm(argc, argv); } - else if(strcmp(argv[1], "gemm_bias_2d") == 0) + else if(strcmp(argv[1], "gemm_splitk") == 0) { - return profile_gemm_bias_2d(argc, argv); + return profile_gemm_splitk(argc, argv); } - else if(strcmp(argv[1], "gemm_bias_relu") == 0) + else if(strcmp(argv[1], "gemm_bilinear") == 0) { - return profile_gemm_bias_relu(argc, argv); + return profile_gemm_bilinear(argc, argv); } - else if(strcmp(argv[1], "gemm_bias_relu_add") == 0) + else if(strcmp(argv[1], "gemm_add_add_fastgelu") == 0) { - return profile_gemm_bias_relu_add(argc, argv); + return profile_gemm_add_add_fastgelu(argc, argv); } else if(strcmp(argv[1], "gemm_reduce") == 0) { return profile_gemm_reduce(argc, argv); } + else if(strcmp(argv[1], "gemm_bias_add_reduce") == 0) + { + return profile_gemm_bias_add_reduce(argc, argv); + } else if(strcmp(argv[1], "batched_gemm") == 0) { return profile_batched_gemm(argc, argv); @@ -58,7 +92,7 @@ int main(int argc, char* argv[]) } else if(strcmp(argv[1], "conv_fwd") == 0) { - return ck::profiler::profile_convnd_fwd(argc, argv); + return profile_convnd_fwd(argc, argv); } else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) { @@ -68,14 +102,6 @@ int main(int argc, char* argv[]) { return profile_conv_fwd_bias_relu_add(argc, argv); } - else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0) - { - return profile_conv_fwd_bias_relu_atomic_add(argc, argv); - } - else if(strcmp(argv[1], "conv_fwd_cpu") == 0) - { - 
return profile_conv_fwd_cpu(argc, argv); - } else if(strcmp(argv[1], "conv1d_bwd_data") == 0) { return profile_convnd_bwd_data(argc, argv, 1); @@ -88,16 +114,34 @@ int main(int argc, char* argv[]) { return profile_convnd_bwd_data(argc, argv, 3); } + else if(strcmp(argv[1], "conv2d_bwd_weight") == 0) + { + return profile_conv_bwd_weight(argc, argv); + } + else if(strcmp(argv[1], "convnd1d_bwd_weight") == 0) + { + return profile_convnd_bwd_weight(argc, argv, 1); + } + else if(strcmp(argv[1], "convnd2d_bwd_weight") == 0) + { + return profile_convnd_bwd_weight(argc, argv, 2); + } + else if(strcmp(argv[1], "convnd3d_bwd_weight") == 0) + { + return profile_convnd_bwd_weight(argc, argv, 3); + } else if(strcmp(argv[1], "reduce") == 0) { return profile_reduce(argc, argv); } - else if(strcmp(argv[1], "conv2d_bwd_weight") == 0) + else if(strcmp(argv[1], "batchnorm") == 0 || strcmp(argv[1], "layernorm") == 0 || + strcmp(argv[1], "softmax") == 0) { - return profile_conv_bwd_weight(argc, argv); + return profile_normalization(argc, argv); } else { +<<<<<<< HEAD // clang-format off printf("arg1: tensor operation (gemm: GEMM\n" " gemm_bias_2d: GEMM+Bias(2D)\n" @@ -118,4 +162,10 @@ int main(int argc, char* argv[]) // clang-format on } return 0; +======= + print_helper_message(); + + return 0; + } +>>>>>>> origin/develop } diff --git a/script/docker-rocm4.1.sh b/script/docker-rocm4.1.sh deleted file mode 100755 index 61cc33c5b845cb47c56693eb491a65e9076a7597..0000000000000000000000000000000000000000 --- a/script/docker-rocm4.1.sh +++ /dev/null @@ -1,14 +0,0 @@ -WORKSPACE=$1 -echo "workspace: " $WORKSPACE - -docker run \ --it \ ---rm \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v $WORKSPACE:/root/workspace \ -rocm/tensorflow:rocm4.1-tf1.15-dev \ -/bin/bash - -#--network host \ diff --git a/script/docker-rocm4.3.1.sh b/script/docker-rocm4.3.1.sh deleted file mode 100755 index 48cb675b69056b1aa7e2e2af6aaf4a2c2a2373cc..0000000000000000000000000000000000000000 --- a/script/docker-rocm4.3.1.sh +++ /dev/null @@ -1,14 +0,0 @@ -WORKSPACE=$1 -echo "workspace: " $WORKSPACE - -docker run \ --it \ ---rm \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v $WORKSPACE:/root/workspace \ -rocm/tensorflow:rocm4.3.1-tf2.6-dev \ -/bin/bash - -#--network host \ diff --git a/script/profile_conv.sh b/script/profile_conv.sh index 0e97ceb6c651e133eafc2ca2096e6f53032c6459..c3ba39c9260bd1034b68c073f4a9fb18aa5e5ba8 100755 --- a/script/profile_conv.sh +++ b/script/profile_conv.sh @@ -26,7 +26,7 @@ REPEAT=$9 N=${10} -# Resnet50 from Bing +# Resnet50 (no duplicated layer) ######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 @@ -47,10 +47,10 @@ REPEAT=$9 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 -#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 8 7 7 224 224 2 2 1 1 3 3 3 3 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT 
$VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 -# Resnet50 from Bing +# Resnet50 fusion ####### op_________________ datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C_ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 diff --git a/script/run_performance_tests.sh b/script/run_performance_tests.sh old mode 100644 new mode 100755 index 6c96a9449d1c704b5efe16a54cf3c3b3f38b75fe..95d63d0ffe02698579b58f5383b42f5e647c091e --- a/script/run_performance_tests.sh +++ b/script/run_performance_tests.sh @@ -1,58 +1,57 @@ -#!/bin/bash -# -# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ -# and make sure the following python packages are installed in your environment: -# pip3 install --upgrade pip -# pip3 install sqlalchemy -# pip3 install pymysql -# pip3 install pandas -# pip3 install sshtunnel -# you would also need to set up some environment variables in order to -# post your new test results to the database and compare them to the baseline -# please contact Illia.Silin@amd.com for more details -# - -export gemm_log="perf_gemm.log" -rm -f $gemm_log -git status | grep -e 'On branch' > ${gemm_log} -echo -n 'Node name: ' >>${gemm_log}; hostname >> ${gemm_log} -#get GPU_arch and number of compute units from rocminfo -echo -n "GPU_arch: " >> ${gemm_log}; rocminfo | grep "Name:" | grep "gfx" >> ${gemm_log} -rocminfo | grep "Compute Unit:" >> ${gemm_log} -hipcc --version | grep -e 'HIP version' >> ${gemm_log} -/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log} -./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log} -./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a $gemm_log -./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a $gemm_log - -python3 parse_perf_data.py ${gemm_log} - -#run resnet50 test -export resnet_log="perf_resnet50.log" -rm -f $resnet_log -git status | grep -e 'On branch' > ${resnet_log} -echo -n 'Node name: '>>${resnet_log}; hostname >>${resnet_log} -#get GPU_arch and number of compute units from rocminfo -echo -n "GPU_arch: " >> ${resnet_log}; rocminfo | grep "Name:" | grep "gfx" >> ${resnet_log} -rocminfo | grep "Compute Unit:" >> ${resnet_log} -hipcc --version | grep -e 'HIP version' >> ${resnet_log} -/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log} -#first run tests with N=256 -./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log} -#then run with N=4 -./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log} -#the script will put the results from 
N=256 and N=4 runs into separate tables -python3 parse_perf_data.py ${resnet_log} +#!/bin/bash +# +# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ +# and make sure the following python packages are installed in your environment: + +pip3 install --upgrade pip +pip3 install sqlalchemy pymysql pandas sshtunnel + +# you would also need to set up some environment variables in order to +# post your new test results to the database and compare them to the baseline +# please contact Illia.Silin@amd.com for more details +# + +export gemm_log="perf_gemm.log" +rm -f $gemm_log +git status | grep -e 'On branch' > ${gemm_log} +echo -n 'Node name: ' >>${gemm_log}; hostname >> ${gemm_log} +#get GPU_arch and number of compute units from rocminfo +echo -n "GPU_arch: " >> ${gemm_log}; rocminfo | grep "Name:" | grep "gfx" >> ${gemm_log} +rocminfo | grep "Compute Unit:" >> ${gemm_log} +hipcc --version | grep -e 'HIP version' >> ${gemm_log} +/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log} +./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log} +./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a $gemm_log + +python3 parse_perf_data.py ${gemm_log} + +#run resnet50 test +export resnet_log="perf_resnet50.log" +rm -f $resnet_log +git status | grep -e 'On branch' > ${resnet_log} +echo -n 'Node name: '>>${resnet_log}; hostname >>${resnet_log} +#get GPU_arch and number of compute units from rocminfo +echo -n "GPU_arch: " >> ${resnet_log}; rocminfo | grep "Name:" | grep "gfx" >> ${resnet_log} +rocminfo | grep "Compute Unit:" >> ${resnet_log} +hipcc --version | grep -e 'HIP version' >> ${resnet_log} +/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log} +#first run tests with N=256 +./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log} +#then run with N=4 +./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log} +#the script will put the results from N=256 and N=4 runs into separate tables +python3 parse_perf_data.py ${resnet_log} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7815d9a8ac0146d58d87c9b80a6d0a03b6a7b417..c44af34aac0fd418256d1175608766f4a480872b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,31 +1,5 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/ - ${PROJECT_SOURCE_DIR}/include/ck - ${PROJECT_SOURCE_DIR}/include/ck/utility - ${PROJECT_SOURCE_DIR}/include/ck/host_utility - ${PROJECT_SOURCE_DIR}/include/ck/tensor_description - ${PROJECT_SOURCE_DIR}/include/ck/tensor - ${PROJECT_SOURCE_DIR}/include/ck/problem_transform - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid - 
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/device - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/grid - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/block - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/thread - ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/cpu/element - ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor - ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance - ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce - ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu - ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu - ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility - ${PROJECT_SOURCE_DIR}/test/include - ${PROJECT_SOURCE_DIR}/profiler/include - ${PROJECT_SOURCE_DIR}/external/include/half ) include(googletest) @@ -39,6 +13,7 @@ function(add_test_executable TEST_NAME) add_test(NAME ${TEST_NAME} COMMAND $ ) add_dependencies(tests ${TEST_NAME}) add_dependencies(check ${TEST_NAME}) + rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) endfunction(add_test_executable TEST_NAME) include(GoogleTest) @@ -52,6 +27,7 @@ function(add_gtest_executable TEST_NAME) target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(${TEST_NAME} PRIVATE gtest_main) gtest_discover_tests(${TEST_NAME}) + rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) endfunction(add_gtest_executable TEST_NAME) @@ -68,7 +44,12 @@ add_subdirectory(grouped_gemm) add_subdirectory(convnd_fwd) add_subdirectory(reduce) add_subdirectory(conv2d_bwd_weight) +add_subdirectory(convnd_bwd_weight) add_subdirectory(convnd_bwd_data) add_subdirectory(block_to_ctile_map) +<<<<<<< HEAD add_subdirectory(cpu_ukernel) # DONOT add client_app, that is tested via CI independently +======= +add_subdirectory(softmax) +>>>>>>> origin/develop diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp index c039e344d29e2c21b029b28150c52f4be086b5ff..7fc1f24f5fd334f0bc2d79193dd981bdc4038a62 100644 --- a/test/batched_gemm/batched_gemm_fp16.cpp +++ b/test/batched_gemm/batched_gemm_fp16.cpp @@ -1,6 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include -#include "profile_batched_gemm_impl.hpp" +#include "profiler/include/profile_batched_gemm_impl.hpp" namespace { using ADataType = ck::half_t; @@ -22,19 +25,19 @@ int main() pass = pass && ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, BatchCount); + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, BatchCount); + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, BatchCount); + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, BatchCount); + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM fp16: " << (pass ? 
"Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm/batched_gemm_util.hpp b/test/batched_gemm/batched_gemm_util.hpp deleted file mode 100644 index 0a5c471d40161762509d113fc5ca5a67d671579f..0000000000000000000000000000000000000000 --- a/test/batched_gemm/batched_gemm_util.hpp +++ /dev/null @@ -1,106 +0,0 @@ -#ifndef BATCHED_GEMM_UTILS_HPP -#define BATCHED_GEMM_UTILS_HPP - -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" - -namespace ck { -namespace batched_gemm_util { - -struct GemmParams -{ - GemmParams() - : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) - { - } - - ck::index_t M; - ck::index_t N; - ck::index_t K; - - ck::index_t StrideA; - ck::index_t StrideB; - ck::index_t StrideC; - - float alpha; - float beta; -}; - -template -void RunHostBatchedGemm(const Tensor& A, - const Tensor& B, - Tensor& C, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) -{ - auto ref_batched_gemm = BatchedGemmInstance{}; - auto ref_invoker = ref_batched_gemm.MakeInvoker(); - - auto ref_argument = - ref_batched_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); -} - -template -void RunDeviceBatchedGemm(DeviceGemmPtr& batched_gemm_ptr, - const ck::batched_gemm_util::GemmParams& params, - const Tensor& A, - const Tensor& B, - Tensor& C, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) -{ - DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); - DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); - DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); - - a_g_m_k_device_buf.ToDevice(A.mData.data()); - b_g_k_n_device_buf.ToDevice(B.mData.data()); - - const auto batch_count = A.mDesc.GetLengths()[0]; - auto invoker_ptr = batched_gemm_ptr->MakeInvokerPointer(); - auto argument_ptr = batched_gemm_ptr->MakeArgumentPointer( - static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_g_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_g_m_n_device_buf.GetDeviceBuffer()), - params.M, - params.N, - params.K, - params.StrideA, - params.StrideB, - params.StrideC, - a_element_op, - b_element_op, - c_element_op, - batch_count); - - if(!batched_gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - throw std::runtime_error( - "wrong! 
device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - invoker_ptr->Run(argument_ptr.get()); - c_g_m_n_device_buf.FromDevice(C.mData.data()); -} - -} // namespace batched_gemm_util -} // namespace ck -#endif diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 3ecf19491be09f53301da8121e634b759cccb27b..fa1a2bf87f37e6911b8726565a9b8150427c4165 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -1,9 +1,3 @@ -include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/profiler/include - ${PROJECT_SOURCE_DIR}/test/include - ${PROJECT_SOURCE_DIR}/external/include/half -) - add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE host_tensor) target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp index 7b311cff170f72944247ef8a8cba041ece0a989e..456d21142fd613b06165d87e83d955584b747f70 100644 --- a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp +++ b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -1,6 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include -#include "profile_batched_gemm_reduce_impl.hpp" +#include "profiler/include/profile_batched_gemm_reduce_impl.hpp" int main() { diff --git a/test/block_to_ctile_map/test_block_to_ctile_map.cpp b/test/block_to_ctile_map/test_block_to_ctile_map.cpp index 662d2a0fa57ee060509bed046303859f1197a218..55d9b59f489203cf68b5aebc1c5d9642801435c0 100644 --- a/test/block_to_ctile_map/test_block_to_ctile_map.cpp +++ b/test/block_to_ctile_map/test_block_to_ctile_map.cpp @@ -1,8 +1,12 @@ -#include -#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "gtest/gtest.h" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
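The batched GEMM profiler calls above now pass explicit batch strides (M * K, K * N and M * N in the packed fp16 test) in addition to the leading-dimension strides. To run just the updated batched GEMM tests after a build, something along these lines should work (the target names are assumptions based on the add_test_executable calls, not verified here):

    cd build && make test_batched_gemm_fp16 test_batched_gemm_reduce_fp16 && ctest -R batched_gemm --output-on-failure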
+ #include #include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" using namespace ck; diff --git a/test/client_app/CMakeLists.txt b/test/client_app/CMakeLists.txt deleted file mode 100644 index f8dd8c4e0ad288bc1bdbb28d5210bfcc9f04c8e7..0000000000000000000000000000000000000000 --- a/test/client_app/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -cmake_minimum_required(VERSION 3.15) -project(ck_app) -add_compile_options(-std=c++14) - -find_package(composable_kernel 1.0.0 COMPONENTS device_operations host_tensor) -find_package(hip REQUIRED PATHS /opt/rocm) -message(STATUS "Build with HIP ${hip_VERSION}") - -add_executable(test_client_app client_app.cpp) - -target_link_libraries(test_client_app PRIVATE composable_kernel::device_operations composable_kernel::host_tensor hip::host) diff --git a/test/client_app/client_app.cpp b/test/client_app/client_app.cpp deleted file mode 100644 index 665a103f706fa3edea340ce32eec591dc0d0f59f..0000000000000000000000000000000000000000 --- a/test/client_app/client_app.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "client_app_impl.hpp" - -int main(int argc, char* argv[]) -{ - if(argc != 25) - { - printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); - printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); - printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); - printf("arg6: verification (0: no; 1: yes)\n"); - printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg9: time kernel (0=n0, 1=yes)\n"); - printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(1); - } - - const ConvDataType data_type = static_cast(std::stoi(argv[2])); - const int in_layout = static_cast(std::stoi(argv[3])); - const int wei_layout = static_cast(std::stoi(argv[4])); - const int out_layout = static_cast(std::stoi(argv[5])); - const bool do_verification = std::stoi(argv[6]); - const int init_method = std::stoi(argv[7]); - const bool do_log = std::stoi(argv[8]); - const bool time_kernel = std::stoi(argv[9]); - - const ck::index_t N = std::stoi(argv[10]); - const ck::index_t K = std::stoi(argv[11]); - const ck::index_t C = std::stoi(argv[12]); - const ck::index_t Y = std::stoi(argv[13]); - const ck::index_t X = std::stoi(argv[14]); - const ck::index_t Hi = std::stoi(argv[15]); - const ck::index_t Wi = std::stoi(argv[16]); - - const ck::index_t conv_stride_h = std::stoi(argv[17]); - const ck::index_t conv_stride_w = std::stoi(argv[18]); - const ck::index_t conv_dilation_h = std::stoi(argv[19]); - const ck::index_t conv_dilation_w = std::stoi(argv[20]); - const ck::index_t in_left_pad_h = std::stoi(argv[21]); - const ck::index_t in_left_pad_w = std::stoi(argv[22]); - const ck::index_t in_right_pad_h = std::stoi(argv[23]); - const ck::index_t in_right_pad_w = std::stoi(argv[24]); - - const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; - const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; - - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - - ck::app::profile_conv_fwd_impl(do_verification, - init_method, - do_log, - time_kernel, - data_type, - N, - K, - C, - 
std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - std::vector{conv_stride_h, conv_stride_w}, - std::vector{conv_dilation_h, conv_dilation_w}, - std::vector{in_left_pad_h, in_left_pad_w}, - std::vector{in_right_pad_h, in_right_pad_w}); - return 1; -} diff --git a/test/client_app/client_app_impl.hpp b/test/client_app/client_app_impl.hpp deleted file mode 100644 index f9e4145ba010f0502422efa5163ebe994096a674..0000000000000000000000000000000000000000 --- a/test/client_app/client_app_impl.hpp +++ /dev/null @@ -1,214 +0,0 @@ -#pragma once - -#include "host_interface.hpp" - -enum ConvDataType -{ - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_BF16_BF16, // 2 - INT8_INT8_INT8, // 3 -}; - -enum ConvInputLayout -{ - NCHW, // 0 - NHWC, // 1 -}; - -enum ConvWeightLayout -{ - KCYX, // 0 - KYXC, // 1 -}; - -enum ConvOutputLayout -{ - NKHW, // 0 - NHWK, // 1 -}; - -void check_hip_error(void) -{ - hipError_t err = hipGetLastError(); - if(err != hipSuccess) - { - std::cerr << "Error: " << hipGetErrorString(err) << std::endl; - exit(err); - } -} -std::string getDeviceName(int device) -{ - struct hipDeviceProp_t prop; - hipGetDeviceProperties(&prop, device); - check_hip_error(); - return std::string(prop.name); -} - -int getDriver(void) -{ - int driver; - hipDriverGetVersion(&driver); - check_hip_error(); - return driver; -} - -namespace ck { -namespace app { -struct DeviceMem -{ - DeviceMem() = delete; - DeviceMem(std::size_t mem_size); - void* GetDeviceBuffer(); - void ToDevice(const void* p); - void FromDevice(void* p); - ~DeviceMem(); - - void* mpDeviceBuf; - std::size_t mMemSize; -}; - -DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) -{ - hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); -} - -void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } - -void DeviceMem::ToDevice(const void* p) -{ - hipGetErrorString( - hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); -} - -void DeviceMem::FromDevice(void* p) -{ - hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); -} - -DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } - -void profile_conv_fwd_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - ConvDataType data_type, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) -{ - const ck::index_t Y = filter_spatial_lengths[0]; - const ck::index_t X = filter_spatial_lengths[1]; - - const ck::index_t Hi = input_spatial_lengths[0]; - const ck::index_t Wi = input_spatial_lengths[1]; - - const ck::index_t Ho = output_spatial_lengths[0]; - const ck::index_t Wo = output_spatial_lengths[1]; - - const auto in_sz = N * C * Hi * Wi; - const auto wei_sz = K * C * Y * X; - const auto out_sz = N * K * Ho * Wo; - - using WeiDataType = float; - using InDataType = float; - using OutDataType = float; - - app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz); - app::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_sz); - app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz); - // data is already on device! 
- - // add device Conv instances - std::vector conv_ptrs; - if(data_type == F16_F16_F16) - { - add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs); - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs); - } - else if(data_type == BF16_BF16_BF16) - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(conv_ptrs); - else if(data_type == F32_F32_F32) - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(conv_ptrs); - else if(data_type == INT8_INT8_INT8) - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(conv_ptrs); - else - throw std::runtime_error("wrong! Invalid data type"); - if(conv_ptrs.empty()) - { - throw std::runtime_error("wrong! no device Conv instance found"); - } - - std::string best_conv_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - int deviceIndex = 0; - hipSetDevice(deviceIndex); - check_hip_error(); - - StreamConfig stream_config{nullptr, time_kernel}; - hipStreamCreate(&stream_config.stream_id_); - check_hip_error(); - - // profile device Conv instances - for(auto& conv_ptr : conv_ptrs) - { - auto argument_ptr = - conv_ptr.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - auto invoker_ptr = conv_ptr.MakeInvokerPointer(); - - if(conv_ptr.IsSupportedArgument(argument_ptr.get())) - { - std::string conv_name = conv_ptr.GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config); - - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; - - std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + - sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << conv_name << std::endl; - - if(tflops > best_tflops) - { - best_conv_name = conv_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; -} - -} // namespace app -} // namespace ck diff --git a/test/conv2d_bwd_data/conv2d_bwd_data.cpp b/test/conv2d_bwd_data/conv2d_bwd_data.cpp index c8eb5413dccc93fd731d0779393cd0867741c1f4..cb9245387ab32e5a5f5ed791de4c5868318b38e0 100644 --- a/test/conv2d_bwd_data/conv2d_bwd_data.cpp +++ b/test/conv2d_bwd_data/conv2d_bwd_data.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -17,7 +20,7 @@ using INT8 = int8_t; namespace ck { namespace tensor_operation { namespace device { -namespace device_conv2d_bwd_data_instance { +namespace instance { using DeviceConvBwdDataNoOpPtr = DeviceConvBwdDataPtr&); -} // namespace device_conv2d_bwd_data_instance +} // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -217,28 +220,28 @@ int main(int argc, char* argv[]) ck::is_same_v, float> && ck::is_same_v, float>) { - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); } else if constexpr(ck::is_same_v, ck::half_t> && ck::is_same_v, ck::half_t> && ck::is_same_v, ck::half_t>) { - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); } else if constexpr(ck::is_same_v, ck::bhalf_t> && ck::is_same_v, ck::bhalf_t> && ck::is_same_v, ck::bhalf_t>) { - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); } else if constexpr(ck::is_same_v, int8_t> && ck::is_same_v, int8_t> && ck::is_same_v, int8_t>) { - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + ck::tensor_operation::device::instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); } diff --git a/test/conv2d_bwd_weight/CMakeLists.txt b/test/conv2d_bwd_weight/CMakeLists.txt index ecd5336c1f3bd61f8f578cdbad18f0aa2935b226..e61c9299c8cdc32769f7044946fd6dc7c4427717 100644 --- a/test/conv2d_bwd_weight/CMakeLists.txt +++ b/test/conv2d_bwd_weight/CMakeLists.txt @@ -1,7 +1,2 @@ -include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/profiler/include - ${PROJECT_SOURCE_DIR}/external/include/half -) - add_test_executable(test_conv2d_bwd_weight conv2d_bwd_weight.cpp) target_link_libraries(test_conv2d_bwd_weight PRIVATE host_tensor device_conv2d_bwd_weight_instance conv_util) diff --git a/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp index 671980f49e402141cae8b3eb20217c52d388644d..7af0fa3d8279854a5b487d86275a5214fc48cf17 100644 --- a/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp +++ b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp @@ -1,13 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
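The conv2d_bwd_data hunks above also show the dispatch pattern used throughout these tests: an `if constexpr` chain over the element types selects which instance factory fills `conv_ptrs`. A minimal standalone sketch of the same pattern, with placeholder factory names rather than the real `add_device_conv2d_bwd_data_*` entry points:

    // Minimal illustration of the compile-time type dispatch used by the tests.
    // Factory names and the Instance type are placeholders, not CK APIs.
    #include <iostream>
    #include <type_traits>
    #include <vector>

    struct Instance { const char* name; };

    static void add_f32_instances(std::vector<Instance>& v) { v.push_back({"f32"}); }
    static void add_f16_instances(std::vector<Instance>& v) { v.push_back({"f16"}); }

    template <typename DataType>
    std::vector<Instance> get_instances()
    {
        std::vector<Instance> ptrs;
        if constexpr(std::is_same_v<DataType, float>)
            add_f32_instances(ptrs);
        else if constexpr(std::is_same_v<DataType, short>) // stand-in for ck::half_t
            add_f16_instances(ptrs);
        return ptrs;
    }

    int main()
    {
        std::cout << get_instances<float>().front().name << std::endl; // prints "f32"
        return 0;
    }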
+ #include #include #include #include -#include -#include #include -#include "conv_util.hpp" -#include "profile_conv_bwd_weight_impl.hpp" +#include "test/convnd_fwd/conv_util.hpp" +#include "profiler/include/profile_conv_bwd_weight_impl.hpp" int test_self() { diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 98f55b872e21aee3e485258453537d13b5ccb7a7..293d94542cf9089e63e081cfaca2352b9fd22b89 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -1,203 +1,207 @@ -#include -#include -#include -#include - -#include "config.hpp" -#include "conv_util.hpp" -#include "tensor_layout.hpp" -#include "check_err.hpp" - -namespace { - -class TestConvUtil : public ::testing::Test -{ - public: - void SetNDParams(std::size_t ndims) - { - conv_params.num_dim_spatial_ = ndims; - conv_params.filter_spatial_lengths_ = std::vector(ndims, 3); - conv_params.input_spatial_lengths_ = std::vector(ndims, 71); - conv_params.conv_filter_strides_ = std::vector(ndims, 2); - conv_params.conv_filter_dilations_ = std::vector(ndims, 1); - conv_params.input_left_pads_ = std::vector(ndims, 1); - conv_params.input_right_pads_ = std::vector(ndims, 1); - } - - protected: - // ------- default 2D ------- - // input NCHW {128,192,71,71}, - // weights KCYX {256,192,3,3}, - // stride {2,2}, - // dilations {1,1}, - // padding {{1,1}, {1,1}} - ck::utils::conv::ConvParams conv_params; -}; - -} // namespace - -TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D) -{ - ck::utils::conv::ConvParams conv_params; - std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{36, 36}, - "Error: ConvParams 2D default constructor.")); - - conv_params.conv_filter_strides_ = std::vector{1, 1}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}.")); - - conv_params.conv_filter_strides_ = std::vector{2, 2}; - conv_params.input_left_pads_ = std::vector{2, 2}; - conv_params.input_right_pads_ = std::vector{2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{37, 37}, - "Error: ConvParams 2D padding left/right {2,2}.")); - - conv_params.conv_filter_dilations_ = std::vector{2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}.")); - - conv_params.conv_filter_strides_ = std::vector{3, 3}; - conv_params.input_left_pads_ = std::vector{1, 1}; - conv_params.input_right_pads_ = std::vector{1, 1}; - conv_params.conv_filter_dilations_ = std::vector{2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE( - ck::utils::check_err(out_spatial_len, - std::vector{23, 23}, - "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.")); -} - -TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D) -{ - SetNDParams(1); - - std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{36}, "Error: ConvParams 1D.")); - - conv_params.conv_filter_strides_ = std::vector{1}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}.")); - - conv_params.conv_filter_strides_ = std::vector{2}; - 
conv_params.input_left_pads_ = std::vector{2}; - conv_params.input_right_pads_ = std::vector{2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{37}, - "Error: ConvParams 1D padding left/right {2}.")); - - conv_params.conv_filter_dilations_ = std::vector{2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}.")); - - conv_params.conv_filter_strides_ = std::vector{3}; - conv_params.input_left_pads_ = std::vector{1}; - conv_params.input_right_pads_ = std::vector{1}; - conv_params.conv_filter_dilations_ = std::vector{2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE( - ck::utils::check_err(out_spatial_len, - std::vector{23}, - "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.")); -} - -TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) -{ - SetNDParams(3); - - std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D.")); - - conv_params.conv_filter_strides_ = std::vector{1, 1, 1}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{71, 71, 71}, - "Error: ConvParams 3D stride {1, 1, 1}.")); - - conv_params.conv_filter_strides_ = std::vector{2, 2, 2}; - conv_params.input_left_pads_ = std::vector{2, 2, 2}; - conv_params.input_right_pads_ = std::vector{2, 2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{37, 37, 37}, - "Error: ConvParams 3D padding left/right {2, 2, 2}.")); - - conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{36, 36, 36}, - "Error: ConvParams 3D dilation {2, 2, 2}.")); - - conv_params.conv_filter_strides_ = std::vector{3, 3, 3}; - conv_params.input_left_pads_ = std::vector{1, 1, 1}; - conv_params.input_right_pads_ = std::vector{1, 1, 1}; - conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, - std::vector{23, 23, 23}, - "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}.")); -} - -TEST(ConvUtil, GetHostTensorDescriptor) -{ - namespace tl = ck::tensor_layout::convolution; - std::vector dims{2, 3, 4, 5}; - HostTensorDescriptor h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{}); - EXPECT_TRUE(ck::utils::check_err( - h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!")); - EXPECT_TRUE(ck::utils::check_err( - h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!")); - - h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCHW{}); - EXPECT_TRUE(ck::utils::check_err( - h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!")); - EXPECT_TRUE(ck::utils::check_err( - h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!")); - - dims = std::vector{2, 3, 4}; - h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{}); - EXPECT_TRUE( - ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!")); - EXPECT_TRUE(ck::utils::check_err( - h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong 
NWC dimensions strides!")); - - h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCW{}); - EXPECT_TRUE( - ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!")); - EXPECT_TRUE(ck::utils::check_err( - h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!")); - - dims = std::vector{2, 3, 4, 5, 6}; - h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{}); - EXPECT_TRUE( - ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!")); - EXPECT_TRUE(ck::utils::check_err(h.GetStrides(), - {3 * 4 * 5 * 6, // N - 1, // C - 3 * 5 * 6, // D - 3 * 6, // H - 3}, // W - "Error: wrong NDHWC dimensions strides!")); - - h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCDHW{}); - EXPECT_TRUE( - ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!")); - EXPECT_TRUE(ck::utils::check_err(h.GetStrides(), - {3 * 4 * 5 * 6, // N - 4 * 5 * 6, // C - 5 * 6, // D - 6, // H - 1}, // W - "Error: wrong NCDHW dimensions strides!")); -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" + +namespace { + +class TestConvUtil : public ::testing::Test +{ + public: + void SetNDParams(std::size_t ndims) + { + conv_params.num_dim_spatial_ = ndims; + conv_params.filter_spatial_lengths_ = std::vector(ndims, 3); + conv_params.input_spatial_lengths_ = std::vector(ndims, 71); + conv_params.conv_filter_strides_ = std::vector(ndims, 2); + conv_params.conv_filter_dilations_ = std::vector(ndims, 1); + conv_params.input_left_pads_ = std::vector(ndims, 1); + conv_params.input_right_pads_ = std::vector(ndims, 1); + } + + protected: + // ------- default 2D ------- + // input NCHW {128,192,71,71}, + // weights KCYX {256,192,3,3}, + // stride {2,2}, + // dilations {1,1}, + // padding {{1,1}, {1,1}} + ck::utils::conv::ConvParams conv_params; +}; + +} // namespace + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D) +{ + ck::utils::conv::ConvParams conv_params; + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D default constructor.")); + + conv_params.conv_filter_strides_ = std::vector{1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}.")); + + conv_params.conv_filter_strides_ = std::vector{2, 2}; + conv_params.input_left_pads_ = std::vector{2, 2}; + conv_params.input_right_pads_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37, 37}, + "Error: ConvParams 2D padding left/right {2,2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}.")); + + conv_params.conv_filter_strides_ = std::vector{3, 3}; + conv_params.input_left_pads_ = std::vector{1, 1}; + conv_params.input_right_pads_ = std::vector{1, 1}; + conv_params.conv_filter_dilations_ = std::vector{2, 2}; + out_spatial_len = 
conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE( + ck::utils::check_err(out_spatial_len, + std::vector{23, 23}, + "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.")); +} + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D) +{ + SetNDParams(1); + + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D.")); + + conv_params.conv_filter_strides_ = std::vector{1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}.")); + + conv_params.conv_filter_strides_ = std::vector{2}; + conv_params.input_left_pads_ = std::vector{2}; + conv_params.input_right_pads_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37}, + "Error: ConvParams 1D padding left/right {2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}.")); + + conv_params.conv_filter_strides_ = std::vector{3}; + conv_params.input_left_pads_ = std::vector{1}; + conv_params.input_right_pads_ = std::vector{1}; + conv_params.conv_filter_dilations_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE( + ck::utils::check_err(out_spatial_len, + std::vector{23}, + "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.")); +} + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) +{ + SetNDParams(3); + + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D.")); + + conv_params.conv_filter_strides_ = std::vector{1, 1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{71, 71, 71}, + "Error: ConvParams 3D stride {1, 1, 1}.")); + + conv_params.conv_filter_strides_ = std::vector{2, 2, 2}; + conv_params.input_left_pads_ = std::vector{2, 2, 2}; + conv_params.input_right_pads_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37, 37, 37}, + "Error: ConvParams 3D padding left/right {2, 2, 2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36, 36}, + "Error: ConvParams 3D dilation {2, 2, 2}.")); + + conv_params.conv_filter_strides_ = std::vector{3, 3, 3}; + conv_params.input_left_pads_ = std::vector{1, 1, 1}; + conv_params.input_right_pads_ = std::vector{1, 1, 1}; + conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, + std::vector{23, 23, 23}, + "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}.")); +} + +TEST(ConvUtil, GetHostTensorDescriptor) +{ + namespace tl = ck::tensor_layout::convolution; + std::vector dims{2, 3, 4, 5}; + HostTensorDescriptor h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{}); + EXPECT_TRUE(ck::utils::check_err( + h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions 
lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCHW{}); + EXPECT_TRUE(ck::utils::check_err( + h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!")); + + dims = std::vector{2, 3, 4}; + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCW{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!")); + + dims = std::vector{2, 3, 4, 5, 6}; + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 1, // C + 3 * 5 * 6, // D + 3 * 6, // H + 3}, // W + "Error: wrong NDHWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCDHW{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 4 * 5 * 6, // C + 5 * 6, // D + 6, // H + 1}, // W + "Error: wrong NCDHW dimensions strides!")); +} diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index 55d71a41d325c4a673032eaf01c2440e4f19cbb0..554bcd18fbb09a4d6c9e35e08c9f20c3abda412a 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -1,7 +1,2 @@ -include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/profiler/include - ${PROJECT_SOURCE_DIR}/external/include/half -) - add_test_executable(test_convnd_bwd_data convnd_bwd_data.cpp) target_link_libraries(test_convnd_bwd_data PRIVATE host_tensor device_convnd_bwd_data_instance conv_util) diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data.cpp index 7284680e0e501e49049a2293b64c305b01cb6111..a5b83b9eed8552277be0be5dbfeb5aa2c3bbb8ae 100644 --- a/test/convnd_bwd_data/convnd_bwd_data.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data.cpp @@ -1,12 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
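The ConvParams tests above (both the removed and re-added versions of conv_util.cpp) all reduce to the standard output-length formula out = (in + pad_left + pad_right - dilation*(filter - 1) - 1) / stride + 1; for the default case, (71 + 1 + 1 - 2 - 1)/2 + 1 = 36. A self-contained check of the expected values (this mirrors the formula, not the CK implementation):

    // Standalone check of the output-length arithmetic the tests above expect.
    #include <cassert>
    #include <cstdint>

    static std::int64_t out_len(std::int64_t in, std::int64_t filter, std::int64_t stride,
                                std::int64_t dilation, std::int64_t pad_l, std::int64_t pad_r)
    {
        return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
    }

    int main()
    {
        assert(out_len(71, 3, 2, 1, 1, 1) == 36); // default case
        assert(out_len(71, 3, 1, 1, 1, 1) == 71); // stride {1}
        assert(out_len(71, 3, 2, 1, 2, 2) == 37); // padding left/right {2}
        assert(out_len(71, 3, 2, 2, 2, 2) == 36); // dilation {2}
        assert(out_len(71, 3, 3, 2, 1, 1) == 23); // stride {3}, padding {1}, dilation {2}
        return 0;
    }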
+ #include #include #include #include -#include -#include #include -#include "profile_convnd_bwd_data_impl.hpp" +#include "profiler/include/profile_convnd_bwd_data_impl.hpp" int main() { diff --git a/test/convnd_bwd_weight/CMakeLists.txt b/test/convnd_bwd_weight/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e76c28bf4f36dd4a2fc7f2bd323c7ac922bd8e5f --- /dev/null +++ b/test/convnd_bwd_weight/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_convnd_bwd_weight convnd_bwd_weight.cpp) +target_link_libraries(test_convnd_bwd_weight PRIVATE host_tensor device_convnd_bwd_weight_instance conv_util) diff --git a/test/convnd_bwd_weight/convnd_bwd_weight.cpp b/test/convnd_bwd_weight/convnd_bwd_weight.cpp new file mode 100644 index 0000000000000000000000000000000000000000..febcef16c086ce5339419acb20d538165be661df --- /dev/null +++ b/test/convnd_bwd_weight/convnd_bwd_weight.cpp @@ -0,0 +1,283 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "test/convnd_fwd/conv_util.hpp" +#include "profiler/include/profile_convnd_bwd_weight_impl.hpp" + +int test_self() +{ + bool pass = true; + std::vector params; + + params.push_back({1, 128, 256, 256, {1}, {7}, {2}, {1}, {0}, {0}}); + params.push_back({1, 128, 256, 256, {3}, {14}, {1}, {1}, {1}, {1}}); + params.push_back({1, 128, 256, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + + for(auto& param : params) + { + // f32 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // fp16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<1, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // bf16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<1, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + } + + // check 2d + params.clear(); + params.push_back({2, 128, 256, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + params.push_back({2, 128, 256, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + params.push_back({2, 128, 256, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + + for(auto& 
param : params) + { + // f32 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // fp16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // bf16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<2, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + } + + // check 2d + params.clear(); + params.push_back( + {3, 128, 256, 256, {1, 1, 1}, {4, 4, 4}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + params.push_back( + {3, 128, 256, 256, {3, 3, 3}, {4, 4, 8}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + params.push_back( + {3, 128, 256, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + for(auto& param : params) + { + // f32 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<3, + float, + float, + float, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // fp16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<3, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // bf16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<3, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 
1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + } + + return pass; +} +int main() +{ + // int data_type = 1; + // int init_method = 1; + + bool pass = true; + + pass = test_self(); + + if(pass) + { + std::cout << "test conv2d bwd weight : Pass" << std::endl; + return 0; + } + else + { + std::cout << "test conv2d bwd weight: Fail " << std::endl; + return -1; + } +} diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 34e698681b2b7767f6b53ed70c9c3a76673114ae..444ec6c8aaa9fd7efe5c86ae007356ccb45ce02d 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -5,7 +5,7 @@ target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_inst add_dependencies(test_convnd_fwd test_conv1d_fwd) add_gtest_executable(test_conv2d_fwd conv2d_fwd.cpp) -target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance conv_util) +target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance device_convnd_2d_fwd_instance conv_util) add_dependencies(test_convnd_fwd test_conv2d_fwd) add_gtest_executable(test_conv3d_fwd conv3d_fwd.cpp) diff --git a/test/convnd_fwd/conv1d_fwd.cpp b/test/convnd_fwd/conv1d_fwd.cpp index b6b6a89b2ce447fc200b30697f448288c73f7d92..4d2473f020b1c4b5bcad1f363b697bb3b72b6e75 100644 --- a/test/convnd_fwd/conv1d_fwd.cpp +++ b/test/convnd_fwd/conv1d_fwd.cpp @@ -1,93 +1,192 @@ -#include -#include -#include -#include -#include "gtest/gtest.h" - -#include "data_type.hpp" -#include "element_wise_operation.hpp" -#include "library/include/ck/library/utility/conv_util.hpp" -#include "conv_util.hpp" - -namespace { - -template -bool test_conv1d_nwc_instances(const std::vector& conv_ptrs) -{ - using namespace std::placeholders; - using namespace ck::utils; - namespace ctl = ck::tensor_layout::convolution; - - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ = 1; - params.filter_spatial_lengths_ = std::vector{3}; - params.input_spatial_lengths_ = std::vector{71}; - params.conv_filter_strides_ = std::vector{2}; - params.conv_filter_dilations_ = std::vector{1}; - params.input_left_pads_ = std::vector{1}; - params.input_right_pads_ = std::vector{1}; - - conv::ConvFwdOpInstance conv_instance(params); - - auto reference_conv_fwd_fun = - std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); -} - -} // anonymous namespace - -TEST(Conv1DFwdNWC, TestConv1D) -{ - using namespace std::placeholders; - using namespace ck::utils; - namespace ctl = ck::tensor_layout::convolution; - - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ = 1; - params.N_ = 2; - params.K_ = 16; - params.C_ = 4; - params.filter_spatial_lengths_ = std::vector{3}; - params.input_spatial_lengths_ = std::vector{16}; - params.conv_filter_strides_ = std::vector{1}; - params.conv_filter_dilations_ = std::vector{1}; - params.input_left_pads_ = std::vector{1}; - params.input_right_pads_ = std::vector{1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<1>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance( - params); - - auto reference_conv_fwd_fun = std::bind( - 
conv::run_reference_convolution_forward<1, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - run_engine.SetAtol(1e-5); - run_engine.SetRtol(1e-4); - EXPECT_TRUE(run_engine.Test(conv_ptrs)); -} - -TEST(Conv1DFwdNWC, Bf16Iinstances) -{ - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); -} - -TEST(Conv1DFwdNWC, F16Instances) -{ - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); -} - -TEST(Conv1DFwdNWC, F32Instances) -{ - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); -} - -TEST(Conv1DFwdNWC, Int8Instances) -{ - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "test/convnd_fwd/conv_util.hpp" + +namespace { + +class Conv1dFwdNWCInstances : public ::testing::Test +{ + public: + template + bool test_conv1d_nwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), params_default_); + } + + template + bool test_filter1x1_stride1_pad0() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), + params_filter1x1_stride1_pad0_); + } + + template + bool test_filter1x1_pad0() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), + params_filter1x1_pad0_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 1, 4, 256, 64, {3}, {71}, {2}, {2}, {2}, {2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 1, 4, 256, 64, {1}, {28}, {1}, {1}, {0}, {0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 1, 4, 256, 64, {1}, {28}, {2}, {1}, {0}, {0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv1DFwdNWC, IntegerValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = float; + + ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + + auto reference_conv_fwd_fun = + 
std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-4); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv1DFwdNWC, FloatingPointValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(0.1); + run_engine.SetRtol(1e-2); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST_F(Conv1dFwdNWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv1dFwdNWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv1dFwdNWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv1dFwdNWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp index 05e46147be1a33604b9a364b1576db654c0b5a8d..f45805782c34db1f8161fcbab8abb83eba3e5650 100644 --- a/test/convnd_fwd/conv2d_fwd.cpp +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -1,91 +1,266 @@ -#include -#include -#include -#include -#include "gtest/gtest.h" - -#include "data_type.hpp" -#include "element_wise_operation.hpp" -#include "ck/library/utility/conv_util.hpp" -#include "conv_util.hpp" - -namespace { - -template -bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs) -{ - using namespace std::placeholders; - using namespace ck::utils; - - conv::ConvParams params; - params.num_dim_spatial_ = 2; - params.filter_spatial_lengths_ = std::vector{3, 3}; - params.input_spatial_lengths_ = std::vector{71, 71}; - params.conv_filter_strides_ = std::vector{2, 2}; - params.conv_filter_dilations_ = std::vector{1, 1}; - params.input_left_pads_ = std::vector{1, 1}; - params.input_right_pads_ = std::vector{1, 1}; - - conv::ConvFwdOpInstance conv_instance(params); - - auto reference_conv_fwd_fun = - std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return 
run_engine.Test(conv_ptrs); -} - -} // anonymous namespace - -TEST(Conv2DFwdNHWC, TestConv2D) -{ - using namespace std::placeholders; - using namespace ck::utils; - - ck::utils::conv::ConvParams params; - params.N_ = 2; - params.K_ = 16; - params.C_ = 4; - params.input_spatial_lengths_ = std::vector{16, 16}; - params.conv_filter_strides_ = std::vector{1, 1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<2>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance(params); - - auto reference_conv_fwd_fun = std::bind( - conv::run_reference_convolution_forward<2, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - run_engine.SetAtol(1e-5); - run_engine.SetRtol(1e-4); - EXPECT_TRUE(run_engine.Test(conv_ptrs)); -} - -TEST(Conv2DFwdNHWC, Bf16Instances) -{ - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); -} - -TEST(Conv2DFwdNHWC, F16Instances) -{ - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); -} - -TEST(Conv2DFwdNHWC, BF32Instances) -{ - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); -} - -TEST(Conv2DFwdNHWC, F32Instances) -{ - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); -} - -TEST(Conv2DFwdNHWC, Int8Instances) -{ - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "test/convnd_fwd/conv_util.hpp" + +namespace { + +class Conv2dFwdNHWCInstances : public ::testing::Test +{ + public: + template + bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), params_default_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_default_); + } + } + + template + bool test_filter1x1_stride1_pad0(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), + params_filter1x1_stride1_pad0_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_stride1_pad0_); + } + } + + template + bool test_filter1x1_pad0(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), params_filter1x1_pad0_); + } + else + { + return 
test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_pad0_); + } + } + + template + bool test_oddC() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), params_oddC_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 2, 4, 256, 64, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 2, 4, 256, 64, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_oddC_{ + 2, 4, 256, 3, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv2DFwdNHWC, IntegerValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + using T = float; + + ck::utils::conv::ConvParams params{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {1, 1}, {2, 2}, {2, 2}, {2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-4); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv2DFwdNHWC, FloatingPointValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(2e-4); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST_F(Conv2dFwdNHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_oddC) { EXPECT_TRUE(this->test_oddC()); } +TEST_F(Conv2dFwdNHWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, I8_default) { 
EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_default) +{ + EXPECT_TRUE(this->test_default(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_default) +{ + EXPECT_TRUE(this->test_default(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F32_default) { EXPECT_TRUE(this->test_default(true)); } +TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_I8_default) { EXPECT_TRUE(this->test_default(true)); } +TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index c6f0e7ec07fb5e638c9844324bb61712d48b5dc3..0cc2b2416eb379b3199a574b1ec1863fbb41bd83 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -1,214 +1,317 @@ -#include -#include -#include -#include -#include -#include "gtest/gtest.h" - -#include "data_type.hpp" -#include "element_wise_operation.hpp" -#include "library/include/ck/library/utility/conv_util.hpp" -#include "conv_util.hpp" - -namespace { - -template -bool test_conv3d_ndhwc_instances(const std::vector& conv_ptrs) -{ - using namespace std::placeholders; - using namespace ck::utils; - namespace ctl = ck::tensor_layout::convolution; - - conv::ConvParams params; - params.N_ = 64; - params.num_dim_spatial_ = 3; - params.filter_spatial_lengths_ = std::vector{3, 3, 2}; - params.input_spatial_lengths_ = std::vector{32, 32, 2}; - params.conv_filter_strides_ = std::vector{2, 2, 2}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{1, 1, 1}; - params.input_right_pads_ = std::vector{1, 1, 1}; - - conv::ConvFwdOpInstance conv_instance(params); - - auto reference_conv_fwd_fun = - std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); -} - -} // anonymous namespace - -TEST(Conv3DFwdNDHWC, TestConv3D) -{ - using namespace std::placeholders; - using namespace ck::utils; - namespace ctl = ck::tensor_layout::convolution; - - conv::ConvParams params; - params.num_dim_spatial_ = 3; - params.N_ = 2; - params.K_ = 16; - params.C_ = 4; - params.filter_spatial_lengths_ = std::vector{3, 3, 3}; - params.input_spatial_lengths_ = std::vector{16, 16, 16}; - params.conv_filter_strides_ = std::vector{1, 1, 1}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{1, 1, 1}; - 
params.input_right_pads_ = std::vector{1, 1, 1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance( - params); - - auto reference_conv_fwd_fun = std::bind( - conv::run_reference_convolution_forward<3, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - run_engine.SetAtol(1e-5); - run_engine.SetRtol(1e-4); - EXPECT_TRUE(run_engine.Test(conv_ptrs)); -} - -TEST(Conv3DFwdNDHWC, InputOver2GB) -{ - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using namespace ck::utils; - - // >2GB Input - conv::ConvParams params; - params.num_dim_spatial_ = 3; - params.N_ = 2; - params.K_ = 16; - params.C_ = 32; - params.filter_spatial_lengths_ = std::vector{3, 3, 3}; - params.input_spatial_lengths_ = std::vector{32, 1000, 1000}; - params.conv_filter_strides_ = std::vector{1, 1, 1}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{1, 1, 1}; - params.input_right_pads_ = std::vector{1, 1, 1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - - auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, - nullptr, - nullptr, - params.N_, - params.K_, - params.C_, - params.input_spatial_lengths_, - params.filter_spatial_lengths_, - params.GetOutputSpatialLengths(), - params.conv_filter_strides_, - params.conv_filter_dilations_, - params.input_left_pads_, - params.input_right_pads_, - PassThrough{}, - PassThrough{}, - PassThrough{}); - EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); -} - -TEST(Conv3DFwdNDHWC, FiltersOver2GB) -{ - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using namespace ck::utils; - - // >2GB Filters - conv::ConvParams params; - params.num_dim_spatial_ = 3; - params.N_ = 2; - params.K_ = 16; - params.C_ = 32; - params.filter_spatial_lengths_ = std::vector{4, 1000, 1000}; - params.input_spatial_lengths_ = std::vector{16, 16, 16}; - params.conv_filter_strides_ = std::vector{1, 1, 1}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{1, 1, 1}; - params.input_right_pads_ = std::vector{1, 1, 1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - - auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, - nullptr, - nullptr, - params.N_, - params.K_, - params.C_, - params.input_spatial_lengths_, - params.filter_spatial_lengths_, - params.GetOutputSpatialLengths(), - params.conv_filter_strides_, - params.conv_filter_dilations_, - params.input_left_pads_, - params.input_right_pads_, - PassThrough{}, - PassThrough{}, - PassThrough{}); - EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); -} - -TEST(Conv3DFwdNDHWC, OutputOver2GB) -{ - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using namespace ck::utils; - - // >2GB Output - conv::ConvParams params; - params.num_dim_spatial_ = 3; - params.N_ = 2; - params.K_ = 16; - params.C_ = 2; - params.filter_spatial_lengths_ = std::vector{1, 1, 1}; - params.input_spatial_lengths_ = std::vector{1000, 1000, 30}; - params.conv_filter_strides_ = std::vector{1, 1, 1}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{2, 2, 2}; - params.input_right_pads_ = std::vector{2, 2, 2}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - auto arg = 
conv_ptrs.back()->MakeArgumentPointer(nullptr, - nullptr, - nullptr, - params.N_, - params.K_, - params.C_, - params.input_spatial_lengths_, - params.filter_spatial_lengths_, - params.GetOutputSpatialLengths(), - params.conv_filter_strides_, - params.conv_filter_dilations_, - params.input_left_pads_, - params.input_right_pads_, - PassThrough{}, - PassThrough{}, - PassThrough{}); - EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); -} - -TEST(Conv3DFwdNDHWC, Bf16Instances) -{ - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); -} - -TEST(Conv3DFwdNDHWC, F16Instances) -{ - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); -} - -TEST(Conv3DFwdNDHWC, F32Instances) -{ - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); -} - -TEST(Conv3DFwdNDHWC, Int8Instances) -{ - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/conv_util.hpp" + +#include "test/convnd_fwd/conv_util.hpp" + +namespace { + +class Conv3dFwdNDHWCInstances : public ::testing::Test +{ + public: + template + bool test_conv3d_nwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), params_default_); + } + + template + bool test_filter1x1_stride1_pad0() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), + params_filter1x1_stride1_pad0_); + } + + template + bool test_filter1x1_pad0() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), + params_filter1x1_pad0_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 3, 4, 256, 64, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv3DFwdNDHWC, IntegerValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = float; + + ck::utils::conv::ConvParams params{ + 3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, 
{2, 2, 2}, {2, 2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv3DFwdNDHWC, FloatingPointValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{ + 3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-3); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv3DFwdNDHWC, InputOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + using T = float; + + // >2GB Input + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 32; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{32, 1000, 1000}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST(Conv3DFwdNDHWC, FiltersOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + using T = float; + + // >2GB Filters + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 32; + params.filter_spatial_lengths_ = std::vector{4, 1000, 1000}; + params.input_spatial_lengths_ = std::vector{16, 16, 16}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + 
params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST(Conv3DFwdNDHWC, OutputOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + using T = float; + + // >2GB Output + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{1, 1, 1}; + params.input_spatial_lengths_ = std::vector{1000, 1000, 30}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{2, 2, 2}; + params.input_right_pads_ = std::vector{2, 2, 2}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST_F(Conv3dFwdNDHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv3dFwdNDHWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv3dFwdNDHWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv3dFwdNDHWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} diff --git a/test/convnd_fwd/conv_util.hpp b/test/convnd_fwd/conv_util.hpp index 09f641b415142d1e6738c1007ed3e959f289a494..c698bbd05c4e418072f4e0526da5fbbd00a3c7d3 100644 --- a/test/convnd_fwd/conv_util.hpp +++ b/test/convnd_fwd/conv_util.hpp @@ -1,13 +1,35 @@ -#ifndef TEST_CONV_UTIL_HPP -#define TEST_CONV_UTIL_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once #include -#include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "sequence.hpp" +#include "ck/ck.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr; +namespace instance { + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck namespace test { namespace conv { @@ -25,57 +47,128 @@ using DeviceConvFwdNoOpPtr = static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -template +template using DeviceConvNDFwdInstance = ck::tensor_operation::device:: DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< // clang-format off InDataType, // WeiDataType, // OutDataType, // - InDataType, // + AccDataType, // Accumulator data type. InElementOp, // Input Elementwise Operation WeiElementOp, // Weights Elementwise Operation OutElementOp, // Output Elementwise Operation ConvFwdDefault, // ConvForwardSpecialization SpatialDims, // SptialDims - 64, // BlockSize - 16, // MPerBlock - 16, // NPerBlock + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock 4, // K0PerBlock - 1, // K1 - 16, // MPerXDL - 16, // NPerXDL - 1, // MXdlPerWave - 1, // NXdlPerWave - S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim - 1, // ABlockTransferSrcScalarPerVector - 1, // ABlockTransferDstScalarPerVector_K1 + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 true, // ABlockLdsAddExtraM - S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim - 1, // BBlockTransferSrcScalarPerVector - 1, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockTransferAddExtraN + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector + 1>; // CThreadTransferDstScalarPerVector // clang-format on template + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType> void get_test_convolution_fwd_instance(std::vector& instances) { - using ConvInstanceT = DeviceConvNDFwdInstance; + using ConvInstanceT = + DeviceConvNDFwdInstance; instances.emplace_back(std::make_unique()); } +// TODO (aosewski) +// Temporary solution to get all 
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K +// instances. When switched over to DeviceConvNDFwdXdl for 2D remove ConvolutionNDFwdInstances +// structures. +template +struct ConvolutionNDFwdInstances; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + } // namespace conv } // namespace test - -#endif diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index b8679e37157e293a0152a5cce256ece06b565c18..83b3c1e2e3049cbed8a1a00653d054facf8c0c22 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -1,29 +1,15 @@ -# GEMM XDL -add_test_executable(test_gemm_xdl_fp32 gemm_xdl_fp32.cpp) -target_link_libraries(test_gemm_xdl_fp32 PRIVATE host_tensor) -target_link_libraries(test_gemm_xdl_fp32 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_fp32 gemm_fp32.cpp) +target_link_libraries(test_gemm_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_xdl_fp16 gemm_xdl_fp16.cpp) -target_link_libraries(test_gemm_xdl_fp16 PRIVATE host_tensor) -target_link_libraries(test_gemm_xdl_fp16 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_fp16 gemm_fp16.cpp) +target_link_libraries(test_gemm_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_xdl_bf16 gemm_xdl_bf16.cpp) -target_link_libraries(test_gemm_xdl_bf16 PRIVATE host_tensor) -target_link_libraries(test_gemm_xdl_bf16 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_bf16 gemm_bf16.cpp) +target_link_libraries(test_gemm_bf16 PRIVATE host_tensor) +target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_xdl_int8 gemm_xdl_int8.cpp) -target_link_libraries(test_gemm_xdl_int8 PRIVATE host_tensor) -target_link_libraries(test_gemm_xdl_int8 PRIVATE device_gemm_instance) - -# GEMM DL -add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp) -target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor) -target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_instance) - -add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp) -target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor) -target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_instance) - -add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp) 
-target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor) -TArget_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_int8 gemm_int8.cpp) +target_link_libraries(test_gemm_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7ecc892dcde7ec86d0a0d3a509340f50b57c892 --- /dev/null +++ b/test/gemm/gemm_bf16.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = ck::bhalf_t; + using BDataType = ck::bhalf_t; + using CDataType = ck::bhalf_t; + using AccDataType = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 
0 : 1; +} diff --git a/test/gemm/gemm_dl_fp16.cpp b/test/gemm/gemm_dl_fp16.cpp deleted file mode 100644 index 8a539372badc69063cd08afca5f00cdb5196b957..0000000000000000000000000000000000000000 --- a/test/gemm/gemm_dl_fp16.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "../gemm/gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -int main() -{ - using ADataType = ck::half_t; - using BDataType = ck::half_t; - using CDataType = ck::half_t; - using AccDataType = float; - - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - bool res = true; - - std::vector gemmPtrs; - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 
0 : 1; -} diff --git a/test/gemm/gemm_dl_fp32.cpp b/test/gemm/gemm_dl_fp32.cpp deleted file mode 100644 index 3484458042e00cd460e25443165cf5939c27f5c1..0000000000000000000000000000000000000000 --- a/test/gemm/gemm_dl_fp32.cpp +++ /dev/null @@ -1,133 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "../gemm/gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -int main() -{ - using ADataType = float; - using BDataType = float; - using CDataType = float; - using AccDataType = float; - - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - bool res = true; - std::vector gemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 
0 : 1; -} diff --git a/test/gemm/gemm_dl_int8.cpp b/test/gemm/gemm_dl_int8.cpp deleted file mode 100644 index 5dfb7221cb602326dcaa72d02bc6d7e59c3fffa5..0000000000000000000000000000000000000000 --- a/test/gemm/gemm_dl_int8.cpp +++ /dev/null @@ -1,133 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "../gemm/gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_dl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -int main() -{ - using ADataType = int8_t; - using BDataType = int8_t; - using CDataType = int8_t; - using AccDataType = int; - - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - bool res = true; - std::vector gemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 0 : 1; -} diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea9864abeb4b85b6082035fe7fa0767ba2e6bb21 --- /dev/null +++ b/test/gemm/gemm_fp16.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + using AccDataType = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b66addd71278d3ce133bf34aca27354f7dc65b7c --- /dev/null +++ b/test/gemm/gemm_fp32.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = float; + using BDataType = float; + using CDataType = float; + using AccDataType = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_fp64.cpp b/test/gemm/gemm_fp64.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e0b9cab3707b604177f18859641812f21e04a7b9 --- /dev/null +++ b/test/gemm/gemm_fp64.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = double; + using BDataType = double; + using CDataType = double; + using AccDataType = double; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..972f4079752417781914275f0b8aa6207502c11b --- /dev/null +++ b/test/gemm/gemm_int8.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + using AccDataType = int32_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index a3cafa6df16c1aa52fd8a1e229e53630ade3cba9..4528c4aaeff5b8da1aeeed14b9a8137a3928b96c 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -1,13 +1,15 @@ -#ifndef GEMM_UTILS_HPP -#define GEMM_UTILS_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "reference_gemm.hpp" -#include "tensor_layout.hpp" +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { namespace gemm_util { @@ -157,7 +159,7 @@ struct TestGemm return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); } - auto operator()(DeviceGemmPtr_& gemmPtr) + auto operator()(const DeviceGemmPtr_& gemmPtr) { std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name << ", CLayout = " << CLayout{}.name << std::endl; @@ -212,6 +214,11 @@ struct TestGemm res = ck::utils::check_err(c_device.mData, c_host.mData); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? 
"SUCCESS" : "FAILURE") << std::endl; + } else if(std::is_same::value) { res = ck::utils::check_err(c_device.mData, c_host.mData); @@ -232,122 +239,5 @@ struct TestGemm } }; -template -struct TestGemmBF16 -{ - using BF16 = ck::bhalf_t; - - auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params) - { - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - // use fp32 host kernel to verify bf16 device kernel - Tensor a_m_k_bf16( - f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_k_n_bf16( - f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_device_bf16( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - - Tensor a_m_k_fp32( - f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_k_n_fp32( - f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_host_fp32( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor c_m_n_device_fp32( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - - a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - - bf16_to_f32_(a_m_k_bf16, a_m_k_fp32); - bf16_to_f32_(b_k_n_bf16, b_k_n_fp32); - - return std::make_tuple(a_m_k_bf16, - b_k_n_bf16, - c_m_n_device_bf16, - a_m_k_fp32, - b_k_n_fp32, - c_m_n_host_fp32, - c_m_n_device_fp32); - } - - auto operator()(DeviceGemmPtr_& gemmPtr) - { - // Arrange - ck::gemm_util::GemmParams params; - params.M = 1024; - params.N = 1024; - params.K = 1024; - params.StrideA = 1024; - params.StrideB = 1024; - params.StrideC = 1024; - - auto host_tensors = PrepareGemmTensorBF16(params); - const Tensor& a_bf16 = std::get<0>(host_tensors); - const Tensor& b_bf16 = std::get<1>(host_tensors); - Tensor& c_device_bf16 = std::get<2>(host_tensors); - Tensor& a_fp32 = std::get<3>(host_tensors); - Tensor& b_fp32 = std::get<4>(host_tensors); - Tensor& c_host_fp32 = std::get<5>(host_tensors); - Tensor& c_device_fp32 = std::get<6>(host_tensors); - - auto a_element_op = AElementwiseOperation{}; - auto b_element_op = BElementwiseOperation{}; - auto c_element_op = CElementwiseOperation{}; - - // use fp32 host kernel to verify bf16 device kernel - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemm; - ck::gemm_util::RunHostGEMM( - a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op); - - // Act - ck::gemm_util::RunDeviceGEMM(gemmPtr, - params, - a_bf16, - b_bf16, - c_device_bf16, - a_element_op, - b_element_op, - c_element_op); - - bf16_to_f32_(c_device_bf16, c_device_fp32); - - // Assert - bool res = ck::utils::check_err( - c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); - std::cout << (res ? 
"SUCCESS" : "FAILURE") << std::endl; - - return res; - }; -}; - } // namespace gemm_util } // namespace ck -#endif diff --git a/test/gemm/gemm_xdl_bf16.cpp b/test/gemm/gemm_xdl_bf16.cpp deleted file mode 100644 index 5461088b0225a50a2f477aeeb00a315ec293e425..0000000000000000000000000000000000000000 --- a/test/gemm/gemm_xdl_bf16.cpp +++ /dev/null @@ -1,116 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector&); -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -int main() -{ - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - bool res = true; - std::vector gemmPtrs; - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); - } - - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 
0 : 1; -} diff --git a/test/gemm/gemm_xdl_fp16.cpp b/test/gemm/gemm_xdl_fp16.cpp deleted file mode 100644 index 6fe3f83d1cd8d4d740f1d78e960fd26715828ecd..0000000000000000000000000000000000000000 --- a/test/gemm/gemm_xdl_fp16.cpp +++ /dev/null @@ -1,160 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); - -void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - std::vector&); -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -int main() -{ - using ADataType = ck::half_t; - using BDataType = ck::half_t; - using CDataType = ck::half_t; - using AccDataType = float; - - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - bool res = true; - std::vector gemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); - 
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 0 : 1; -} diff --git a/test/gemm/gemm_xdl_fp32.cpp b/test/gemm/gemm_xdl_fp32.cpp deleted file mode 100644 index 4756d1b4d6f6d0f6556fc0f92accddf14dfa60cb..0000000000000000000000000000000000000000 --- a/test/gemm/gemm_xdl_fp32.cpp +++ /dev/null @@ -1,159 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { -void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); - -void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -int main() -{ - using ADataType = float; - using BDataType = float; - using CDataType = float; - using AccDataType = float; - - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - bool res = true; - std::vector gemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - 
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 0 : 1; -} diff --git a/test/gemm/gemm_xdl_fp64.cpp b/test/gemm/gemm_xdl_fp64.cpp deleted file mode 100644 index db37211505d185740a3d6dcb01855b2efcc6ab35..0000000000000000000000000000000000000000 --- a/test/gemm/gemm_xdl_fp64.cpp +++ /dev/null @@ -1,156 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { -void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -inline std::string get_device_name() -{ - hipDeviceProp_t props{}; - int device; - auto status = hipGetDevice(&device); - if(status != hipSuccess) - { - return std::string(); - } - - status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) - { - return std::string(); - } - const std::string name(props.gcnArchName); - - return name; -} - -int main() -{ - if(get_device_name().find("gfx90a") == std::string::npos) - { - std::cout << "TestGemm ..... 
SUCCESS" << std::endl; - return 0; - } - using ADataType = double; - using BDataType = double; - using CDataType = double; - using AccDataType = double; - - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - bool res = true; - std::vector gemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 0 : 1; -} diff --git a/test/gemm/gemm_xdl_int8.cpp b/test/gemm/gemm_xdl_int8.cpp deleted file mode 100644 index 0075b79cf7bd3e18823f0f76ba475382c3b6e4d6..0000000000000000000000000000000000000000 --- a/test/gemm/gemm_xdl_int8.cpp +++ /dev/null @@ -1,133 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -int main() -{ - using ADataType = int8_t; - using BDataType = int8_t; - using CDataType = int8_t; - using AccDataType = int32_t; - - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - std::vector gemmPtrs; - bool res = true; - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - 
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 0 : 1; -} diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index e474af3230118721f184c1b7b44806f1243bf7d8..74b787ac27ed54ec03c38ca135e8caeb10f54223 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,9 +1,3 @@ -include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/profiler/include - ${PROJECT_SOURCE_DIR}/test/include - ${PROJECT_SOURCE_DIR}/external/include/half -) - add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) target_link_libraries(test_gemm_reduce_fp16 PRIVATE host_tensor) target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance) diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp index 6c7bb9658fd1a8076081c9f881ad1aefd0d6d46d..16f787e07e60123ef99b913fd2f096573330b22a 100644 --- a/test/gemm_reduce/gemm_reduce_fp16.cpp +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -1,6 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include -#include "profile_gemm_reduce_impl.hpp" +#include "profiler/include/profile_gemm_reduce_impl.hpp" int main() { diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index 40d422377bcb2b38d36373d74b896ba7f83f43f2..ab1d016c9d4281a5b8309ace510b07b9d8691076 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,3 +1,3 @@ add_test_executable(test_gemm_split_k gemm_split_k.cpp) target_link_libraries(test_gemm_split_k PRIVATE host_tensor) -target_link_libraries(test_gemm_split_k PRIVATE device_gemm_instance) +target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance) diff --git a/test/gemm_split_k/gemm_split_k.cpp b/test/gemm_split_k/gemm_split_k.cpp index b63361aa1b24ea50dac6ce7732dd6f8db528dfe7..fa06d76e36c94fd7c09dab14a43a08feb0af6c45 100644 --- a/test/gemm_split_k/gemm_split_k.cpp +++ b/test/gemm_split_k/gemm_split_k.cpp @@ -1,16 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include -#include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "host_gemm.hpp" -#include "tensor_layout.hpp" -#include "device_gemm_xdl_splitk.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/library/host_tensor/host_gemm.hpp" enum struct GemmMatrixLayout { @@ -20,26 +29,6 @@ enum struct GemmMatrixLayout KM_NK_MN, // 3 }; -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - template static bool check_out(const Tensor& ref, const Tensor& result) { @@ -71,6 +60,11 @@ struct gemmArgs int test_gemm(const gemmArgs& args) { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + bool a_row_major, b_row_major, c_row_major; switch(args.layout) @@ -141,64 +135,79 @@ int test_gemm(const gemmArgs& args) b_device_buf.ToDevice(b_k_n.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - // add device GEMM instances - std::vector gemm_ptrs; + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool success = false; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK; + + const auto gemm_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + args.M, + args.N, + args.K, + args.StrideA, + args.StrideB, + args.StrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + args.KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get()); + + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(!check_out(c_m_n_host_result, c_m_n_device_result)) + { + success = false; + break; + } + success = true; + } + } + + return success; + }; + + bool success = false; if(args.layout == GemmMatrixLayout::MK_KN_MN) { - 
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + success = test(Row{}, Row{}, Row{}); } else if(args.layout == GemmMatrixLayout::MK_NK_MN) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + success = test(Row{}, Col{}, Row{}); } else if(args.layout == GemmMatrixLayout::KM_KN_MN) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + success = test(Col{}, Row{}, Row{}); } else { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + success = test(Col{}, Col{}, Row{}); } - bool success = false; - for(auto& gemm_ptr : gemm_ptrs) - { - auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - args.M, - args.N, - args.K, - args.StrideA, - args.StrideB, - args.StrideC, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - args.KBatch); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get()); - - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(!check_out(c_m_n_host_result, c_m_n_device_result)) - { - success = false; - break; - } - success = true; - } - } auto error_code = 0; if(success) { diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp index fc8ec66b51a98fb83ebad953e4f9a4e627a59cfd..5418ee02bde9de46e4779767e2ad2ddb4cc33c75 100644 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -1,22 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
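+// Test overview: the grouped GEMM fp16 instances are now declared in the
+// ck::tensor_operation::device::instance namespace (previously
+// device_grouped_gemm_instance); the test registers them through
+// add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances and validates
+// each returned instance against the host reference.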
+ #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_grouped_gemm_xdl.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -28,7 +28,7 @@ using DeviceGroupedGemmPtr_ = ck::tensor_operation::device::DeviceGroupedGemmPtr namespace ck { namespace tensor_operation { namespace device { -namespace device_grouped_gemm_instance { +namespace instance { void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( std::vector&); } @@ -197,7 +197,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) int main() { std::vector groupedGemmPtrs; - ck::tensor_operation::device::device_grouped_gemm_instance:: + ck::tensor_operation::device::instance:: add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(groupedGemmPtrs); bool res = true; diff --git a/test/magic_number_division/magic_number_division.cpp b/test/magic_number_division/magic_number_division.cpp index 751a62be19951f1ed7386e0b5b381dcfc8088f34..79811416080f03af790b4b4dc0fdc1969ba4afae 100644 --- a/test/magic_number_division/magic_number_division.cpp +++ b/test/magic_number_division/magic_number_division.cpp @@ -1,17 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "magic_division.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" + +#include "ck/ck.hpp" +#include "ck/utility/magic_division.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" __global__ void gpu_magic_number_division(uint32_t magic_multiplier, uint32_t magic_shift, diff --git a/test/reduce/reduce_no_index.cpp b/test/reduce/reduce_no_index.cpp index 20030392b5a36fa940b1c982a8b345c398964141..843a6b110a7b133621d41f7d8544d720cfa089b8 100644 --- a/test/reduce/reduce_no_index.cpp +++ b/test/reduce/reduce_no_index.cpp @@ -1,7 +1,10 @@ -#include "getopt.h" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "host_common_util.hpp" -#include "profile_reduce_impl.hpp" +#include + +#include "ck/library/host_tensor/host_common_util.hpp" +#include "profiler/include/profile_reduce_impl.hpp" using namespace ck; diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp index c1918bf38868c2d0931d3cd58978cb0652bf8573..64f16b80857b908c521d2112d89f9d8fc127e016 100644 --- a/test/reduce/reduce_with_index.cpp +++ b/test/reduce/reduce_with_index.cpp @@ -1,7 +1,10 @@ -#include "getopt.h" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "host_common_util.hpp" -#include "profile_reduce_impl.hpp" +#include + +#include "ck/library/host_tensor/host_common_util.hpp" +#include "profiler/include/profile_reduce_impl.hpp" using namespace ck; diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index 69b223989fde46adf0ca59d58cf124a072f42abc..2b5591675f4b2791fabed0a50f39758b11267c5b 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -1,389 +1,392 @@ -#include -#include -#include -#include -#include -#include -#include "gtest/gtest.h" - -#include "check_err.hpp" -#include "config.hpp" -#include "conv_util.hpp" -#include "element_wise_operation.hpp" -#include "fill.hpp" -#include "host_tensor.hpp" -#include "reference_conv_fwd.hpp" -#include "tensor_layout.hpp" - -namespace { -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -template , - typename FillWeightsOp = ck::utils::FillConstant> -Tensor -run_reference_convolution_forward(const ck::utils::conv::ConvParams& params, - const FillInputOp& fill_input_op = FillInputOp{}, - const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) -{ - std::vector input_dims{static_cast(params.N_), - static_cast(params.C_)}; - input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths_), - std::end(params.input_spatial_lengths_)); - - std::vector filter_dims{static_cast(params.K_), - static_cast(params.C_)}; - filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths_), - std::end(params.filter_spatial_lengths_)); - - const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N_), - static_cast(params.K_)}; - output_dims.insert(std::end(output_dims), - std::begin(output_spatial_lengths), - std::end(output_spatial_lengths)); - - Tensor input(ck::utils::conv::get_host_tensor_descriptor(input_dims, InLayout{})); - Tensor weights( - ck::utils::conv::get_host_tensor_descriptor(filter_dims, WeiLayout{})); - Tensor host_output( - ck::utils::conv::get_host_tensor_descriptor(output_dims, OutLayout{})); - - fill_input_op(input.begin(), input.end()); - fill_weights_op(weights.begin(), weights.end()); - std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); - - auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(input, - weights, - host_output, - params.conv_filter_strides_, - params.conv_filter_dilations_, - params.input_left_pads_, - params.input_right_pads_, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); - return host_output; -} - -} // 
anonymous namespace - -TEST(ReferenceConvolutionFWD, Conv2DNHWC) -{ - ck::utils::conv::ConvParams params; - params.N_ = 1; - params.K_ = 1; - params.C_ = 2; - params.filter_spatial_lengths_ = std::vector{3, 3}; - params.input_spatial_lengths_ = std::vector{6, 6}; - params.conv_filter_strides_ = std::vector{1, 1}; - params.conv_filter_dilations_ = std::vector{1, 1}; - params.input_left_pads_ = std::vector{0, 0}; - params.input_right_pads_ = std::vector{0, 0}; - - auto out_tensor = run_reference_convolution_forward<2>(params); - std::vector ref_dims{1, 1, 4, 4}; - std::vector ref_data{130.5, - 148.5, - 166.5, - 184.5, - 238.5, - 256.5, - 274.5, - 292.5, - 346.5, - 364.5, - 382.5, - 400.5, - 454.5, - 472.5, - 490.5, - 508.5}; - EXPECT_TRUE(ck::utils::check_err( - out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); - EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); -} - -TEST(ReferenceConvolutionFWD, Conv2DNHWCStridesDilationsPadding) -{ - ck::utils::conv::ConvParams params; - params.N_ = 1; - params.K_ = 2; - params.C_ = 2; - params.filter_spatial_lengths_ = std::vector{3, 3}; - params.input_spatial_lengths_ = std::vector{12, 12}; - params.conv_filter_strides_ = std::vector{2, 2}; - params.conv_filter_dilations_ = std::vector{2, 2}; - params.input_left_pads_ = std::vector{1, 1}; - params.input_right_pads_ = std::vector{1, 1}; - - auto out_tensor = run_reference_convolution_forward<2>(params); - std::vector ref_dims = std::vector{1, 2, 5, 5}; - std::vector ref_data{ - 210., 210., 327., 327., 351., 351., 375., 375., 399., 399., - 459., 459., 706.5, 706.5, 742.5, 742.5, 778.5, 778.5, 814.5, 814.5, - 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, - 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, - 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; - EXPECT_TRUE(ck::utils::check_err( - out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); - EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); -} - -TEST(ReferenceConvolutionFWD, Conv1DNWC) -{ - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ = 1; - params.N_ = 1; - params.K_ = 1; - params.C_ = 2; - params.filter_spatial_lengths_ = std::vector{3}; - params.input_spatial_lengths_ = std::vector{6}; - params.conv_filter_strides_ = std::vector{1}; - params.conv_filter_dilations_ = std::vector{1}; - params.input_left_pads_ = std::vector{0}; - params.input_right_pads_ = std::vector{0}; - - auto out_tensor = - run_reference_convolution_forward<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>(params); - std::vector ref_dims{1, 1, 4}; - std::vector ref_data{7.5, 13.5, 19.5, 25.5}; - EXPECT_TRUE(ck::utils::check_err( - out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); - EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); -} - -TEST(ReferenceConvolutionFWD, Conv1DNWCStridesDilationsPadding) -{ - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ = 1; - params.N_ = 1; - params.K_ = 2; - params.C_ = 2; - params.filter_spatial_lengths_ = std::vector{3}; - params.input_spatial_lengths_ = std::vector{12}; - params.conv_filter_strides_ = std::vector{2}; - params.conv_filter_dilations_ = std::vector{2}; - params.input_left_pads_ = 
std::vector{1}; - params.input_right_pads_ = std::vector{1}; - - auto out_tensor = - run_reference_convolution_forward<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>(params); - std::vector ref_dims{1, 2, 5}; - std::vector ref_data{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; - EXPECT_TRUE(ck::utils::check_err( - out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); - EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); -} - -TEST(ReferenceConvolutionFWD, Conv1DNWCSameOutputSize) -{ - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ = 1; - params.N_ = 2; - params.K_ = 16; - params.C_ = 4; - params.filter_spatial_lengths_ = std::vector{3}; - params.input_spatial_lengths_ = std::vector{16}; - params.conv_filter_strides_ = std::vector{1}; - params.conv_filter_dilations_ = std::vector{1}; - params.input_left_pads_ = std::vector{1}; - params.input_right_pads_ = std::vector{1}; - - auto out_tensor2 = run_reference_convolution_forward<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>( - params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); - - std::vector ref_dims{2, 16, 16}; - std::vector ref_data{ - 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, - 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, - 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, - 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, - 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, - 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, - 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, - 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, - 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, - 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, - 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, - 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, - 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, - 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, - 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, - 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, - 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, - 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, - 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, - 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, - 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, - 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, - 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, - 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, - 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, - 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, - 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, - 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, - 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, - 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, - 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, - 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, - 27., 27., 27., 27., 27., 27., 27., 27., - 27., 27., 27., 27., 27., 27., 27., 27., - 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, - 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, - 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, - 44.100002, 44.100002, 
44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, - 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, - 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, - 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, - 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, - 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, - 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, - 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, - 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, - 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, - 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, - 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, - 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, - 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, - 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, - 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, - 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, - 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, - 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, - 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, - 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, - 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, - 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, - 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, - 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, - 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, - 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; - EXPECT_TRUE(ck::utils::check_err( - out_tensor2.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); - EXPECT_TRUE(ck::utils::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!")); -} - -TEST(ReferenceConvolutionFWD, Conv3DNCDHW) -{ - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ = 3; - params.N_ = 1; - params.K_ = 1; - params.C_ = 2; - params.filter_spatial_lengths_ = std::vector{3, 3, 3}; - params.input_spatial_lengths_ = std::vector{6, 6, 6}; - params.conv_filter_strides_ = std::vector{1, 1, 1}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{0, 0, 0}; - params.input_right_pads_ = std::vector{0, 0, 0}; - - auto out_tensor = run_reference_convolution_forward<3, - float, - float, - float, - ck::tensor_layout::convolution::NCDHW, - ck::tensor_layout::convolution::KCZYX, - ck::tensor_layout::convolution::NKDHW>( - params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); - std::vector ref_dims{1, 1, 4, 4, 4}; - std::vector ref_data{ - 407.7, 410.40002, 413.09998, 415.80002, 423.90002, 426.6, 429.30002, 432., - 440.1, 442.80002, 445.5, 448.2, 456.30002, 459., 461.7, 464.40002, - 504.90002, 507.6, 510.30002, 513., 521.1, 523.8, 526.5, 529.2001, - 537.3, 540., 542.7001, 545.4, 553.5, 556.2001, 558.9, 561.6, - 602.10004, 604.8, 607.5, 610.2, 618.3, 621., 623.7, 626.4, - 634.5, 637.2, 639.9, 642.60004, 650.7, 653.4, 656.10004, 658.8, - 699.3, 702., 704.7, 707.4, 715.5, 718.2, 720.9, 723.60004, - 731.7, 734.4001, 737.10004, 739.8, 747.9001, 750.60004, 753.3, 756.}; - EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error [case 1]: wrong output tensor dimensions!")); - EXPECT_TRUE( - ck::utils::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!")); -} - -TEST(ReferenceConvolutionFWD, Conv3DNCDHWStridesDilations) -{ - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ 
= 3; - params.N_ = 1; - params.K_ = 2; - params.C_ = 2; - params.filter_spatial_lengths_ = std::vector{3, 3, 3}; - params.input_spatial_lengths_ = std::vector{12, 12, 12}; - params.conv_filter_strides_ = std::vector{3, 3, 3}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{0, 0, 0}; - params.input_right_pads_ = std::vector{0, 0, 0}; - - auto out_tensor = run_reference_convolution_forward<3, - float, - float, - float, - ck::tensor_layout::convolution::NCDHW, - ck::tensor_layout::convolution::KCZYX, - ck::tensor_layout::convolution::NKDHW>( - params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); - std::vector ref_dims{1, 2, 4, 4, 4}; - std::vector ref_data{ - 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, - 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, - 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, - 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., - 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., - 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, - 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, - 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801, - 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, - 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, - 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, - 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., - 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., - 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, - 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, - 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801}; - EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error [case 2]: wrong output tensor dimensions!")); - EXPECT_TRUE(ck::utils::check_err( - out_tensor.mData, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f)); -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
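+// Test overview: the gtest cases below exercise the host ReferenceConvFwd for
+// 1D (NWC), 2D (NHWC) and 3D (NCDHW) layouts, with and without strides,
+// dilations and padding, comparing the outputs against precomputed reference
+// tensors.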
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace { +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +template , + typename FillWeightsOp = ck::utils::FillConstant> +Tensor +run_reference_convolution_forward(const ck::utils::conv::ConvParams& params, + const FillInputOp& fill_input_op = FillInputOp{}, + const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) +{ + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::utils::conv::get_host_tensor_descriptor(input_dims, InLayout{})); + Tensor weights( + ck::utils::conv::get_host_tensor_descriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::utils::conv::get_host_tensor_descriptor(output_dims, OutLayout{})); + + fill_input_op(input.begin(), input.end()); + fill_weights_op(weights.begin(), weights.end()); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + return host_output; +} + +} // anonymous namespace + +TEST(ReferenceConvolutionFWD, Conv2DNHWC) +{ + ck::utils::conv::ConvParams params; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{6, 6}; + params.conv_filter_strides_ = std::vector{1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1}; + params.input_left_pads_ = std::vector{0, 0}; + params.input_right_pads_ = std::vector{0, 0}; + + auto out_tensor = run_reference_convolution_forward<2>(params); + std::vector ref_dims{1, 1, 4, 4}; + std::vector ref_data{130.5, + 148.5, + 166.5, + 184.5, + 238.5, + 256.5, + 274.5, + 292.5, + 346.5, + 364.5, + 382.5, + 400.5, + 454.5, + 472.5, + 490.5, + 508.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, 
Conv2DNHWCStridesDilationsPadding) +{ + ck::utils::conv::ConvParams params; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{12, 12}; + params.conv_filter_strides_ = std::vector{2, 2}; + params.conv_filter_dilations_ = std::vector{2, 2}; + params.input_left_pads_ = std::vector{1, 1}; + params.input_right_pads_ = std::vector{1, 1}; + + auto out_tensor = run_reference_convolution_forward<2>(params); + std::vector ref_dims = std::vector{1, 2, 5, 5}; + std::vector ref_data{ + 210., 210., 327., 327., 351., 351., 375., 375., 399., 399., + 459., 459., 706.5, 706.5, 742.5, 742.5, 778.5, 778.5, 814.5, 814.5, + 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, + 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, + 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWC) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{6}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{0}; + params.input_right_pads_ = std::vector{0}; + + auto out_tensor = + run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + std::vector ref_dims{1, 1, 4}; + std::vector ref_data{7.5, 13.5, 19.5, 25.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWCStridesDilationsPadding) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{12}; + params.conv_filter_strides_ = std::vector{2}; + params.conv_filter_dilations_ = std::vector{2}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; + + auto out_tensor = + run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + std::vector ref_dims{1, 2, 5}; + std::vector ref_data{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWCSameOutputSize) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 2; + params.K_ = 16; + params.C_ = 4; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{16}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + 
params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; + + auto out_tensor2 = run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + + std::vector ref_dims{2, 16, 16}; + std::vector ref_data{ + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 27., 27., 27., 27., 27., 27., 27., 27., + 27., 27., 27., 27., 27., 27., 27., 27., + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 
60.899998, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor2.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv3DNCDHW) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{6, 6, 6}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{0, 0, 0}; + params.input_right_pads_ = std::vector{0, 0, 0}; + + auto out_tensor = run_reference_convolution_forward<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + std::vector ref_dims{1, 1, 4, 4, 4}; + std::vector ref_data{ + 407.7, 410.40002, 413.09998, 415.80002, 423.90002, 426.6, 429.30002, 432., + 440.1, 442.80002, 445.5, 448.2, 456.30002, 459., 461.7, 464.40002, + 504.90002, 507.6, 510.30002, 513., 521.1, 523.8, 526.5, 529.2001, + 537.3, 540., 542.7001, 545.4, 553.5, 556.2001, 558.9, 561.6, + 602.10004, 604.8, 607.5, 610.2, 618.3, 621., 623.7, 626.4, + 634.5, 637.2, 639.9, 642.60004, 650.7, 653.4, 656.10004, 658.8, + 699.3, 702., 704.7, 707.4, 715.5, 718.2, 720.9, 723.60004, + 731.7, 734.4001, 737.10004, 739.8, 747.9001, 750.60004, 753.3, 756.}; + EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 1]: wrong output tensor dimensions!")); + EXPECT_TRUE( + ck::utils::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv3DNCDHWStridesDilations) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{12, 12, 12}; + params.conv_filter_strides_ = std::vector{3, 3, 3}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{0, 0, 0}; + params.input_right_pads_ = std::vector{0, 0, 0}; + + auto out_tensor = run_reference_convolution_forward<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + std::vector ref_dims{1, 2, 4, 4, 4}; + std::vector ref_data{ + 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, + 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, + 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, + 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 
4222.8, 4230.9004, 4239., + 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., + 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, + 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, + 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801, + 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, + 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, + 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, + 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., + 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., + 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, + 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, + 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801}; + EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 2]: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mData, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f)); +} diff --git a/test/softmax/CMakeLists.txt b/test/softmax/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..da80e372eaf8ace600d83c2c7b163e0c4bf0f6bf --- /dev/null +++ b/test/softmax/CMakeLists.txt @@ -0,0 +1,11 @@ +add_custom_target(test_softmax) + +add_gtest_executable(test_softmax_fp32 test_softmax_fp32.cpp) +add_gtest_executable(test_softmax_fp16 test_softmax_fp16.cpp) +add_gtest_executable(test_softmax_int8 test_softmax_int8.cpp) +target_link_libraries(test_softmax_fp32 PRIVATE host_tensor) +target_link_libraries(test_softmax_fp16 PRIVATE host_tensor) +target_link_libraries(test_softmax_int8 PRIVATE host_tensor) +add_dependencies(test_softmax test_softmax_fp32) +add_dependencies(test_softmax test_softmax_fp16) +add_dependencies(test_softmax test_softmax_int8) \ No newline at end of file diff --git a/test/softmax/test_softmax_fp16.cpp b/test/softmax/test_softmax_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cce6a422b6a9f6552056b17b1b61d448493f4a19 --- /dev/null +++ b/test/softmax/test_softmax_fp16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
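+// Test overview: each tuple below instantiates TestSoftmax (see
+// test_softmax_util.hpp) with a fixed block size, thread-cluster shape and
+// vector sizes; the typed test runs DeviceSoftmax over a set of input shapes
+// and compares the result against the host ReferenceSoftmax.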
+ +#include "gtest/gtest.h" +#include "test_softmax_util.hpp" + +template +using I = ck::Number; + +template +class TestSoftmaxFP16 : public ck::TestSoftmax +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< +// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<4>>, // mixed precision + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<32>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<32>, I<1>, I<8>, I<8>> + >; +// clang-format on +TYPED_TEST_SUITE(TestSoftmaxFP16, KernelTypes); +TYPED_TEST(TestSoftmaxFP16, Test_FP16) { this->Run(); } diff --git a/test/softmax/test_softmax_fp32.cpp b/test/softmax/test_softmax_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4301a5ae2f8ffb529d9fdad16ed834b78317aaff --- /dev/null +++ b/test/softmax/test_softmax_fp32.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "test_softmax_util.hpp" + +template +using I = ck::Number; + +template +class TestSoftmaxFP32 : public ck::TestSoftmax +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< +// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<8>>, // mixed precision + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<4>, I<4>> + >; +// clang-format on +TYPED_TEST_SUITE(TestSoftmaxFP32, KernelTypes); +TYPED_TEST(TestSoftmaxFP32, Test_FP32) { this->Run(); } diff --git a/test/softmax/test_softmax_int8.cpp b/test/softmax/test_softmax_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dde165295e5fd3a08897473f3325406683cf2a7e --- /dev/null +++ b/test/softmax/test_softmax_int8.cpp @@ -0,0 +1,30 @@ +#include "gtest/gtest.h" +#include "test_softmax_util.hpp" + +template +using I = ck::Number; + +template +class TestSoftmaxINT8 : public ck::TestSoftmax +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< +// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<32>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<64>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<8>, I<32>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<4>, I<64>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<2>, I<128>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<32>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<64>, I<1>, I<16>, I<16>> + >; +// clang-format on +TYPED_TEST_SUITE(TestSoftmaxINT8, KernelTypes); +TYPED_TEST(TestSoftmaxINT8, Test_INT8) { this->Run(); } diff --git a/test/softmax/test_softmax_util.hpp b/test/softmax/test_softmax_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2ca3b47abc288bde79b3ae45c468bbd4fe15d427 --- /dev/null +++ b/test/softmax/test_softmax_util.hpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT 
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/number.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +namespace ck { + +template +std::string serialize_range(const Range& range) +{ + std::stringstream ss; + for(auto& r : range) + { + ss << r << ", "; + } + std::string str = ss.str(); + return std::string(str.begin(), str.end() - 2); +} + +template +class TestSoftmax : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using AccDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + static constexpr index_t Rank = std::tuple_element_t<3, Tuple>{}.value; + static constexpr index_t NumReduceDim = std::tuple_element_t<4, Tuple>{}.value; + static constexpr index_t BlockSize = std::tuple_element_t<5, Tuple>{}.value; + static constexpr index_t MThreadClusterSize = std::tuple_element_t<6, Tuple>{}.value; + static constexpr index_t KThreadClusterSize = std::tuple_element_t<7, Tuple>{}.value; + static constexpr index_t MThreadSliceSize = std::tuple_element_t<8, Tuple>{}.value; + static constexpr index_t KThreadSliceSize = std::tuple_element_t<9, Tuple>{}.value; + static constexpr index_t InSrcVectorDim = std::tuple_element_t<10, Tuple>{}.value; + static constexpr index_t InSrcVectorSize = std::tuple_element_t<11, Tuple>{}.value; + static constexpr index_t OutDstVectorSize = std::tuple_element_t<12, Tuple>{}.value; + + using ReferenceInstance = + tensor_operation::host::ReferenceSoftmax; + + using DeviceInstance = tensor_operation::device::DeviceSoftmax; + + TestSoftmax() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {} + + void RunSingle(std::vector in_length, AccDataType alpha, AccDataType beta) + { + std::vector reduce_dims(NumReduceDim); + std::iota(reduce_dims.begin(), reduce_dims.end(), Rank - NumReduceDim); + + Tensor in(in_length); + Tensor out(in_length); + + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + Tensor out_ref(out); + + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + in_dev.ToDevice(in.mData.data()); + out_dev.ToDevice(out.mData.data()); + + std::vector i_in_lengths(in.mDesc.GetLengths().begin(), + in.mDesc.GetLengths().end()); + std::vector i_in_strides(in.mDesc.GetStrides().begin(), + in.mDesc.GetStrides().end()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer(i_in_lengths, + i_in_strides, + reduce_dims, + &alpha, + &beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer()); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + // std::cout << "Skipped due to unsupported argument: " + // << "input lengths = [" << serialize_range(in_length) << "], " + // << "scaler = [" << alpha << ", " << beta << "]." 
<< std::endl; + return; + } + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get()); + + ref_instance_invoker_.Run({in, out_ref, alpha, beta, reduce_dims}); + + out_dev.FromDevice(out.mData.data()); + + bool pass; + + if(std::is_same::value) + { + EXPECT_TRUE(pass = ck::utils::check_err( + out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1)); + } + else + { + EXPECT_TRUE(pass = ck::utils::check_err(out.mData, out_ref.mData)); + } + + if(!pass) + { + FAIL() << "Failure in input lengths = [" << serialize_range(in_length) << "], " + << "scaler = [" << alpha << ", " << beta << "]."; + } + } + + void Run() + { + for(auto in_length : this->in_lengths_) + { + for(auto scale : this->scales_) + { + this->RunSingle(in_length, scale[0], scale[1]); + } + } + } + + std::vector> in_lengths_ = { + {1, 8, 128}, {2, 128, 1024}, {3, 9, 1032}, {4, 4, 2048}, {8, 1, 8192}}; + std::vector> scales_ = {{1, 0}, {1, 1}, {0, 1}, {2, 2}}; + + typename ReferenceInstance::Invoker ref_instance_invoker_; +}; +} // namespace ck diff --git a/test/space_filling_curve/space_filling_curve.cpp b/test/space_filling_curve/space_filling_curve.cpp index 635d31d6830b525b1e5104c42cdf61b6ea3255e7..500717dd2bac7ffe5b748cffa5190b3639d5af11 100644 --- a/test/space_filling_curve/space_filling_curve.cpp +++ b/test/space_filling_curve/space_filling_curve.cpp @@ -1,9 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include #include -#include "tensor_space_filling_curve.hpp" +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" using namespace ck;