diff --git a/.gitignore b/.gitignore
index 294863ce8ac98840299ea4dfcb8d78ddb8249eb1..cdf5b64dece05d4fe72e023c9859a60f76124497 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,4 +45,4 @@ build*
*~
# GDB temporary files
-.gdb_history
\ No newline at end of file
+.gdb_history
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b7ad225e2bdd883ae408e4502ba8fd4870f4954c..3c59f574874d99cadf36e0b886103e5f6dfb49da 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,7 +13,8 @@ add_definitions(-DCK_NOGPU)
endif()
if(NOT CK_NOGPU)
-find_package(ROCM REQUIRED PATHS /opt/rocm)
+set(ROCM_SYMLINK_LIBS OFF)
+find_package(ROCM 0.8 REQUIRED PATHS /opt/rocm)
include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers)
@@ -22,7 +23,7 @@ include(ROCMInstallSymlinks)
include(ROCMCreatePackage)
include(CheckCXXCompilerFlag)
-rocm_setup_version(VERSION 1.0.0)
+rocm_setup_version(VERSION 0.2.0)
include(TargetFlags)
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
endif()
@@ -84,19 +85,6 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH )
endif()
message(STATUS "Build with HIP ${HIP_VERSION}")
-
-rocm_create_package(
- NAME composablekernel
- DESCRIPTION "High Performance Composable Kernel for AMD GPUs"
- MAINTAINER "MIOpen Kernels Dev Team"
- LDCONFIG
-)
-endif()
-
-## half
-set(HALF_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/external/include/half")
-message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
-
## tidy
include(EnableCompilerWarnings)
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
@@ -250,7 +238,6 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/include
- ${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR}/library/include
)
@@ -264,6 +251,11 @@ message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
+rocm_package_setup_component(tests
+ LIBRARY_NAME composablekernel
+ PACKAGE_NAME tests # Prevent -static suffix on package name
+)
+
add_subdirectory(library)
add_subdirectory(example)
add_subdirectory(test)
@@ -285,8 +277,19 @@ configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
NO_CHECK_REQUIRED_COMPONENTS_MACRO
)
-install(FILES
+rocm_install(FILES
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
+
+set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
+set(CPACK_RPM_PACKAGE_LICENSE "MIT")
+
+rocm_create_package(
+ NAME composablekernel
+ DESCRIPTION "High Performance Composable Kernel for AMD GPUs"
+ MAINTAINER "MIOpen Kernels Dev Team "
+ LDCONFIG
+ HEADER_ONLY
+)
diff --git a/Dockerfile b/Dockerfile
index 79c961144a3af60b32a95a876111ab4a870596e1..0d32b52f75ac89b0138810af692d7a8177e38f0e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -88,3 +88,8 @@ ADD rbuild.ini /rbuild.ini
ADD dev-requirements.txt dev-requirements.txt
RUN rbuild prepare -s develop -d $PREFIX
RUN groupadd -f render
+
+# Install the new rocm-cmake version
+RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \
+ cd rocm-cmake && mkdir build && cd build && \
+ cmake .. && cmake --build . && cmake --build . --target install
diff --git a/Jenkinsfile b/Jenkinsfile
index beac2ea248fb390150661537625211d7e0abbf8b..15be3e540c49aef417b4f5401eb75d67d41c4465 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -7,7 +7,6 @@ def show_node_info() {
echo "NODE_NAME = \$NODE_NAME"
lsb_release -sd
uname -r
- cat /sys/module/amdgpu/version
ls /opt/ -la
"""
}
@@ -101,7 +100,8 @@ def buildHipClangJob(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
- gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
+
+ gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
if (params.USE_DOCKERFILE){
try {
retimage = docker.build("${image}", dockerArgs + '.')
@@ -191,7 +191,8 @@ def runCKProfiler(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
- gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
+
+ gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
if (params.USE_DOCKERFILE){
try {
retimage = docker.build("${image}", dockerArgs + '.')
@@ -317,6 +318,7 @@ pipeline {
dbsshport = "${dbsshport}"
dbsshuser = "${dbsshuser}"
dbsshpassword = "${dbsshpassword}"
+ status_wrapper_creds = "${status_wrapper_creds}"
}
stages{
stage("Static checks") {
@@ -386,7 +388,7 @@ pipeline {
agent{ label rocmnode("gfx908")}
environment{
setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
- execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """
+ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc .. && make -j """
}
steps{
buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..2fe9a8455efaeda2eab474b2aa038ec2d9e76841
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,28 @@
+Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang)
+Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang)
+Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan)
+Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang)
+Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah)
+Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou)
+Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan)
+
+SPDX-License-Identifier: MIT
+Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 9d7b578046a5e11cafae9ac91ac9419dbf02050a..aa1100dd1381907904ecdfb479ec5aa2609c8798 100644
--- a/README.md
+++ b/README.md
@@ -6,10 +6,13 @@ docker run \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
-rocm/tensorflow:rocm4.3.1-tf2.6-dev \
+rocm/tensorflow:rocm5.1-tf2.6-dev \
/bin/bash
```
+## Install rocm-cmake
+CK requires a recent rocm-cmake (0.8 or newer); it is available from https://github.com/RadeonOpenCompute/rocm-cmake
+
## Build
```bash
mkdir build && cd build
@@ -23,6 +26,7 @@ cmake \
-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
+-D CMAKE_INSTALL_PREFIX=${PATH_TO_CK_INSTALL_DIRECTORY} \
..
```
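+
+```CMAKE_INSTALL_PREFIX``` determines where ```make install``` places CK; client applications later point ```CMAKE_PREFIX_PATH``` at this same directory.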
@@ -34,7 +38,7 @@ Instructions for running each individual examples are under ```example/```
## Tests
```bash
- make -j tests
+ make -j examples tests
make test
```
@@ -44,6 +48,12 @@ Instructions for running each individual examples are under ```example/```
```
Instructions for running ckProfiler are under ```profiler/```
+## Install CK
+```bash
+make install
+```
+
+## Using CK as pre-built kernel library
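+Client applications that use CK as a pre-built kernel library live under ```client_example/```; see ```client_example/README.md``` for build instructions.
+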
## Caveat
### Kernel Timing and Verification
diff --git a/client_example/01_gemm/CMakeLists.txt b/client_example/01_gemm/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9e741192f90b8216e4b3abe32ae8971fb45ddfee
--- /dev/null
+++ b/client_example/01_gemm/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm gemm.cpp)
+target_link_libraries(client_gemm PRIVATE composable_kernel::device_operations)
diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9b7b7a66039b8114dfa10699cf1996383a56e27e
--- /dev/null
+++ b/client_example/01_gemm/gemm.cpp
@@ -0,0 +1,218 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CElementOp = PassThrough;
+
+using ADataType = F16;
+using BDataType = F16;
+using CDataType = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using CLayout = Row;
+
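+// Minimal RAII wrapper around hipMalloc/hipFree; return codes are deliberately
+// ignored to keep the example short.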
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+ (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+ // GEMM shape
+ ck::index_t M = 3840;
+ ck::index_t N = 4096;
+ ck::index_t K = 4096;
+
+ ck::index_t StrideA = 4096;
+ ck::index_t StrideB = 4096;
+ ck::index_t StrideC = 4096;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 7)
+ {
+ M = std::stoi(argv[1]);
+ N = std::stoi(argv[2]);
+ K = std::stoi(argv[3]);
+
+ StrideA = std::stoi(argv[4]);
+ StrideB = std::stoi(argv[5]);
+ StrideC = std::stoi(argv[6]);
+ }
+ else
+ {
+ printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n");
+ exit(0);
+ }
+
+ auto f_matrix_space_size =
+ [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+ using Layout = decltype(layout);
+
+ if(std::is_same<Layout, Row>::value)
+ {
+ return (nRow - 1) * stride + nCol;
+ }
+ else
+ {
+ return (nCol - 1) * stride + nRow;
+ }
+ };
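+ // e.g. a row-major 3840 x 4096 matrix with stride 4096 occupies
+ // (3840 - 1) * 4096 + 4096 = 3840 * 4096 elements.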
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+ SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
+
+ using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout, BLayout, CLayout,
+ ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+
+ // get device op instances
+ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+ DeviceOp>::GetInstances();
+
+ std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+ const auto a_element_op = AElementOp{};
+ const auto b_element_op = BElementOp{};
+ const auto c_element_op = CElementOp{};
+
+ std::string best_op_name;
+ bool found = false;
+ int best_op_id = -1;
+ float best_ave_time = 0;
+ float best_tflops = 0;
+ float best_gb_per_sec = 0;
+
+ // profile device operation instances
+ std::cout << "Run all instances and do timing" << std::endl;
+
+ for(int i = 0; i < op_ptrs.size(); ++i)
+ {
+ auto& op_ptr = op_ptrs[i];
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ c_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ StrideC,
+ a_element_op,
+ b_element_op,
+ c_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ std::string op_name = op_ptr->GetTypeString();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+
+ std::size_t num_btype =
+ sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
+
+ float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+ << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+ if(tflops > best_tflops)
+ {
+ found = true;
+ best_op_id = i;
+ best_op_name = op_name;
+ best_tflops = tflops;
+ best_ave_time = ave_time;
+ best_gb_per_sec = gb_per_sec;
+ }
+ }
+ else
+ {
+ std::cout << op_name << " does not support this problem" << std::endl;
+ }
+ }
+
+ std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+ << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+ // run the best instance
+ {
+ auto& op_ptr = op_ptrs[best_op_id];
+
+ std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+ << std::endl;
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ c_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ StrideC,
+ a_element_op,
+ b_element_op,
+ c_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+ }
+
+ std::cout << "Done" << std::endl;
+ }
+
+ return 0;
+}
diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1064abc8fa813c837d2f85ad61e340517a24e70d
--- /dev/null
+++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
+target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations)
diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dbf2e634f0c9aa11d10639e58576988bef7883c3
--- /dev/null
+++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp
@@ -0,0 +1,241 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = AddAddFastGelu;
+
+using ADataType = F16;
+using BDataType = F16;
+using D0DataType = F16;
+using D1DataType = F16;
+using EDataType = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using DDELayout = Row;
+using DELayout = Row;
+
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+ (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+ // GEMM shape
+ ck::index_t M = 3840;
+ ck::index_t N = 4096;
+ ck::index_t K = 4096;
+
+ ck::index_t StrideA = 4096;
+ ck::index_t StrideB = 4096;
+ ck::index_t StrideD0 = 0;
+ ck::index_t StrideD1 = 4096;
+ ck::index_t StrideE = 4096;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 9)
+ {
+ M = std::stoi(argv[1]);
+ N = std::stoi(argv[2]);
+ K = std::stoi(argv[3]);
+
+ StrideA = std::stoi(argv[4]);
+ StrideB = std::stoi(argv[5]);
+ StrideD0 = std::stoi(argv[6]);
+ StrideD1 = std::stoi(argv[7]);
+ StrideE = std::stoi(argv[8]);
+ }
+ else
+ {
+ printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
+ exit(0);
+ }
+
+ auto f_matrix_space_size =
+ [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+ using Layout = decltype(layout);
+
+ if(std::is_same<Layout, Row>::value)
+ {
+ return (nRow - 1) * stride + nCol;
+ }
+ else
+ {
+ return (nCol - 1) * stride + nRow;
+ }
+ };
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+ SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) *
+ f_matrix_space_size(M, N, StrideD0, DDELayout{}));
+ SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) *
+ f_matrix_space_size(M, N, StrideD1, DDELayout{}));
+ SimpleDeviceMem e_device_buf(sizeof(EDataType) *
+ f_matrix_space_size(M, N, StrideE, DELayout{}));
+
+ using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
+ ALayout,
+ BLayout,
+ DDELayout,
+ ADataType,
+ BDataType,
+ ck::Tuple<D0DataType, D1DataType>,
+ EDataType,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::AddAddFastGelu>;
+
+ // get device op instances
+ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+ DeviceOp>::GetInstances();
+
+ std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+ const auto a_element_op = AElementOp{};
+ const auto b_element_op = BElementOp{};
+ const auto cde_element_op = CDEElementOp{};
+
+ std::string best_op_name;
+ bool found = false;
+ int best_op_id = -1;
+ float best_ave_time = 0;
+ float best_tflops = 0;
+ float best_gb_per_sec = 0;
+
+ // profile device operation instances
+ std::cout << "Run all instances and do timing" << std::endl;
+
+ for(int i = 0; i < op_ptrs.size(); ++i)
+ {
+ auto& op_ptr = op_ptrs[i];
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(
+ a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
+ d1_m_n_device_buf.GetDeviceBuffer()},
+ e_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ std::array<ck::index_t, 2>{StrideD0, StrideD1},
+ StrideE,
+ a_element_op,
+ b_element_op,
+ cde_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ std::string op_name = op_ptr->GetTypeString();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+
+ std::size_t num_btype =
+ sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+ float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+ << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+ if(tflops > best_tflops)
+ {
+ found = true;
+ best_op_id = i;
+ best_op_name = op_name;
+ best_tflops = tflops;
+ best_ave_time = ave_time;
+ best_gb_per_sec = gb_per_sec;
+ }
+ }
+ else
+ {
+ std::cout << op_name << " does not support this problem" << std::endl;
+ }
+ }
+
+ std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+ << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+ // run the best instance
+ {
+ auto& op_ptr = op_ptrs[best_op_id];
+
+ std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+ << std::endl;
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(
+ a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
+ d1_m_n_device_buf.GetDeviceBuffer()},
+ e_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ std::array<ck::index_t, 2>{StrideD0, StrideD1},
+ StrideE,
+ a_element_op,
+ b_element_op,
+ cde_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+ }
+
+ std::cout << "Done" << std::endl;
+ }
+
+ return 0;
+}
diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3742e70844b96575e263b22a14b0bb8c4cde7a43
--- /dev/null
+++ b/client_example/03_gemm_layernorm/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp)
+target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations)
diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8f142937281a712d1004e15a578fc64d6501d473
--- /dev/null
+++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp
@@ -0,0 +1,271 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using ADataType = F16;
+using BDataType = F16;
+using BiasDataType = F32;
+using CDataType = F16;
+using D0DataType = F16;
+using ReduceDataType = F32;
+using GammaDataType = F16;
+using BetaDataType = F16;
+using LayerNormOutDataType = F16;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+ (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+template <typename gemm_reduce_op_ptr>
+bool RunDeviceGemmMeanSquareMean(gemm_reduce_op_ptr& p_op,
+ const void* p_a,
+ const void* p_b,
+ const void* p_bias,
+ const void* p_d0,
+ void* p_c,
+ void* p_mean,
+ void* p_square_mean,
+ int M,
+ int N,
+ int K,
+ int StrideA,
+ int StrideB,
+ int StrideC,
+ int StrideD0,
+ bool time_kernel)
+{
+ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+ using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
+ using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
+
+ auto passOp = PassThrough{};
+ auto squareOp = UnarySquareElementOp{};
+ auto divOp = UnaryDivElementOp{N};
+
+ auto argument_ptr =
+ p_op->MakeArgumentPointer(p_a,
+ p_b,
+ p_bias,
+ {p_d0},
+ p_c,
+ {p_mean, p_square_mean},
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ StrideC,
+ {StrideD0},
+ {&passOp, &passOp, &passOp}, // functor for a, b, c
+ {&passOp}, // functor for d0
+ {&passOp, &squareOp}, // functor for inputs of reduction
+ {&divOp, &divOp}); // functor for outputs of reduction
+
+ if(p_op->IsSupportedArgument(argument_ptr.get()))
+ {
+ auto invoker_ptr = p_op->MakeInvokerPointer();
+
+ // If we measure the running time of gemm_reduce, the output may be wrong:
+ // the reduction tensor must be initialized before running the kernel, but
+ // with time_kernel = true the kernel is run many times without
+ // reinitializing the reduction output tensor.
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+ if(time_kernel)
+ std::cout << "Gemm + reduce Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
+
+ return true;
+ }
+
+ return false;
+}
+
+template <typename normalize_op_ptr>
+bool RunDeviceNormalize2D(normalize_op_ptr& p_op,
+ const void* p_x,
+ const void* p_mean,
+ const void* p_square_mean,
+ const void* p_gamma,
+ const void* p_beta,
+ void* p_y,
+ int M,
+ int N,
+ int StrideX,
+ bool time_kernel)
+{
+ std::array<const void*, 5> input = {p_x, p_mean, p_square_mean, p_gamma, p_beta};
+ std::array<void*, 1> output = {p_y};
+ auto normalize_functor = ck::tensor_operation::element_wise::Normalize{};
+
+ auto argument_ptr = p_op->MakeArgumentPointer(input,
+ output,
+ {M, N},
+ {{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
+ {{StrideX, 1}},
+ normalize_functor);
+
+ if(p_op->IsSupportedArgument(argument_ptr.get()))
+ {
+ auto invoker_ptr = p_op->MakeInvokerPointer();
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+ if(time_kernel)
+ std::cout << "Normalize Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
+
+ return true;
+ }
+
+ return false;
+}
+
+int main()
+{
+ ck::index_t M = 1024;
+ ck::index_t N = 1024;
+ ck::index_t K = 1024;
+
+ ck::index_t StrideA = 1024;
+ ck::index_t StrideB = 1024;
+ ck::index_t StrideC = 1024;
+ ck::index_t StrideD0 = 1024;
+
+ const auto gemm_reduce_ptrs =
+ ck::tensor_operation::device::instance::get_device_gemm_add_add_mean_squaremean_instances<
+ ADataType,
+ BDataType,
+ CDataType,
+ ALayout,
+ BLayout,
+ CLayout>();
+
+ const auto normalize_ptrs =
+ ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
+ CDataType,
+ ReduceDataType,
+ ReduceDataType,
+ GammaDataType,
+ BetaDataType,
+ LayerNormOutDataType>();
+
+ std::cout << "found " << gemm_reduce_ptrs.size()
+ << " gemm_reduceMean_reduceSquareMean instances" << std::endl;
+
+ std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
+
+ auto f_matrix_space_size =
+ [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+ using Layout = decltype(layout);
+
+ if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
+ {
+ return (nRow - 1) * stride + nCol;
+ }
+ else
+ {
+ return (nCol - 1) * stride + nRow;
+ }
+ };
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+ SimpleDeviceMem bias_device_buf(sizeof(BiasDataType) * N);
+ SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
+ SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
+ f_matrix_space_size(M, N, StrideD0, CLayout{}));
+ SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M);
+ SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M);
+ SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
+ SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
+ SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N);
+
+ bool b_time_kernel = true;
+ bool b_only_run_first_kernel = true;
+
+ // layernorm => (1) + (2)
+ // (1). c = gemm(a, b), reduce_mean(c), reduce_square_mean(c)
+ // (2). normalize(c, mean, square_mean, gamma, beta)
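+ // The normalize step recovers the variance as square_mean - mean * mean,
+ // so no second pass over c is needed.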
+ for(auto& gemm_reduce_ptr : gemm_reduce_ptrs)
+ {
+ // run first available kernel
+ if(RunDeviceGemmMeanSquareMean(gemm_reduce_ptr,
+ a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ bias_device_buf.GetDeviceBuffer(),
+ d0_device_buf.GetDeviceBuffer(),
+ c_device_buf.GetDeviceBuffer(),
+ reduceMean_device_buf.GetDeviceBuffer(),
+ reduceMeanSquare_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ StrideC,
+ StrideD0,
+ b_time_kernel))
+ {
+ if(b_only_run_first_kernel)
+ break;
+ }
+ else
+ {
+ std::cout << gemm_reduce_ptr->GetTypeString() << " does not support this problem"
+ << std::endl;
+ }
+ }
+
+ for(auto& normalize_ptr : normalize_ptrs)
+ {
+ if(RunDeviceNormalize2D(normalize_ptr,
+ c_device_buf.GetDeviceBuffer(),
+ reduceMean_device_buf.GetDeviceBuffer(),
+ reduceMeanSquare_device_buf.GetDeviceBuffer(),
+ gamma_device_buf.GetDeviceBuffer(),
+ beta_device_buf.GetDeviceBuffer(),
+ layerNorm_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ StrideC,
+ b_time_kernel))
+ {
+ if(b_only_run_first_kernel)
+ break;
+ }
+ else
+ {
+ std::cout << normalize_ptr->GetTypeString() << " does not support this problem"
+ << std::endl;
+ }
+ }
+}
diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4bc6780f96d2fe4a4912e3c188b4b5155cc162dd
--- /dev/null
+++ b/client_example/04_contraction/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_executable(client_contraction_scale contraction_scale.cpp)
+target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device_operations)
+
+add_executable(client_contraction_bilinear contraction_bilinear.cpp)
+target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations)
+
diff --git a/client_example/04_contraction/contraction_bilinear.cpp b/client_example/04_contraction/contraction_bilinear.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b71c51c02620ce62257e3b33a6165a1c8ddda2b1
--- /dev/null
+++ b/client_example/04_contraction/contraction_bilinear.cpp
@@ -0,0 +1,241 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <numeric>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
+
+using F32 = float;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Bilinear = ck::tensor_operation::element_wise::Bilinear;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = Bilinear;
+
+using ADataType = F32;
+using BDataType = F32;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using DDataType = F32;
+using DsDataType = ck::Tuple<DDataType>;
+using EDataType = F32;
+
+static constexpr ck::index_t NumDimM = 2;
+static constexpr ck::index_t NumDimN = 2;
+static constexpr ck::index_t NumDimK = 2;
+
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+ (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+ // A[M0, M1, K0, K1]
+ std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
+ std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
+ // B[N0, N1, K0, K1]
+ std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
+ std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
+ // D[M0, M1, N0, N1]
+ std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
+ std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
+ // E[M0, M1, N0, N1]
+ std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
+ std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
+
+ float alpha = 1.f;
+ float beta = 1.f;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 25)
+ {
+ const ck::index_t M0 = std::stoi(argv[1]);
+ const ck::index_t M1 = std::stoi(argv[2]);
+
+ const ck::index_t N0 = std::stoi(argv[3]);
+ const ck::index_t N1 = std::stoi(argv[4]);
+
+ const ck::index_t K0 = std::stoi(argv[5]);
+ const ck::index_t K1 = std::stoi(argv[6]);
+
+ a_ms_ks_lengths = {M0, M1, K0, K1};
+ a_ms_ks_strides = {
+ std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])};
+
+ b_ns_ks_lengths = {N0, N1, K0, K1};
+ b_ns_ks_strides = {
+ std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])};
+
+ d_ms_ns_lengths = {M0, M1, N0, N1};
+ d_ms_ns_strides = {
+ std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])};
+
+ e_ms_ns_lengths = {M0, M1, N0, N1};
+ e_ms_ns_strides = {
+ std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])};
+
+ alpha = std::stof(argv[23]);
+ beta = std::stof(argv[24]);
+ }
+ else
+ {
+ printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n");
+ printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
+ printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
+ printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
+ printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
+ printf("arg23 to 24: alpha, beta\n");
+ exit(0);
+ }
+
+ auto f_tensor_space_size = [](auto lengths, auto strides) {
+ std::size_t space_size = 1;
+ for(std::size_t i = 0; i < lengths.size(); ++i)
+ {
+ space_size += (lengths[i] - 1) * strides[i];
+ }
+ return space_size;
+ };
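+ // e.g. lengths {30, 128, 32, 64} with strides {524288, 4096, 128, 1} give
+ // 1 + 29 * 524288 + 127 * 4096 + 31 * 128 + 63 = 15728576 elements.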
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) *
+ f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) *
+ f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
+ SimpleDeviceMem d_device_buf(sizeof(DDataType) *
+ f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides));
+ SimpleDeviceMem e_device_buf(sizeof(EDataType) *
+ f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
+
+ using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
+ NumDimM,
+ NumDimN,
+ NumDimK,
+ ADataType,
+ BDataType,
+ ck::Tuple<DDataType>,
+ EDataType,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::Bilinear>;
+
+ // get device op instances
+ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+ DeviceOp>::GetInstances();
+
+ std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+ const auto a_element_op = AElementOp{};
+ const auto b_element_op = BElementOp{};
+ const auto cde_element_op = CDEElementOp{alpha, beta};
+
+ std::string best_op_name;
+ bool found = false;
+ int best_op_id = -1;
+ float best_ave_time = 0;
+ float best_tflops = 0;
+ float best_gb_per_sec = 0;
+
+ // profile device operation instances
+ std::cout << "Run all instances and do timing" << std::endl;
+
+ for(int i = 0; i < op_ptrs.size(); ++i)
+ {
+ auto& op_ptr = op_ptrs[i];
+
+ auto argument_ptr =
+ op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+ e_device_buf.GetDeviceBuffer(),
+ a_ms_ks_lengths,
+ a_ms_ks_strides,
+ b_ns_ks_lengths,
+ b_ns_ks_strides,
+ std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+ std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+ e_ms_ns_lengths,
+ e_ms_ns_strides,
+ a_element_op,
+ b_element_op,
+ cde_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ std::string op_name = op_ptr->GetTypeString();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+ ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
+ e_ms_ns_lengths.begin() + NumDimM,
+ ck::index_t{1},
+ std::multiplies{});
+
+ ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
+ e_ms_ns_lengths.begin() + NumDimM + NumDimN,
+ ck::index_t{1},
+ std::multiplies{});
+
+ ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
+ a_ms_ks_lengths.begin() + NumDimM + NumDimK,
+ ck::index_t{1},
+ std::multiplies{});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+ std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+ sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
+
+ float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+ << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+ if(tflops > best_tflops)
+ {
+ found = true;
+ best_op_id = i;
+ best_op_name = op_name;
+ best_tflops = tflops;
+ best_ave_time = ave_time;
+ best_gb_per_sec = gb_per_sec;
+ }
+ }
+ else
+ {
+ std::cout << op_name << " does not support this problem" << std::endl;
+ }
+ }
+
+ std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+ << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+ return 0;
+}
diff --git a/client_example/04_contraction/contraction_scale.cpp b/client_example/04_contraction/contraction_scale.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5908c1d86e678796dec3d2616c83e9fca40595fb
--- /dev/null
+++ b/client_example/04_contraction/contraction_scale.cpp
@@ -0,0 +1,227 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <numeric>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp"
+
+using F32 = float;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Scale = ck::tensor_operation::element_wise::Scale;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = Scale;
+
+using ADataType = F32;
+using BDataType = F32;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using DsDataType = ck::Tuple<>;
+using EDataType = F32;
+
+static constexpr ck::index_t NumDimM = 2;
+static constexpr ck::index_t NumDimN = 2;
+static constexpr ck::index_t NumDimK = 2;
+
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+ (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+ // A[M0, M1, K0, K1]
+ std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
+ std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
+ // B[N0, N1, K0, K1]
+ std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
+ std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
+ // E[M0, M1, N0, N1]
+ std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
+ std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
+
+ float scale = 1.f;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 20)
+ {
+ const ck::index_t M0 = std::stoi(argv[1]);
+ const ck::index_t M1 = std::stoi(argv[2]);
+
+ const ck::index_t N0 = std::stoi(argv[3]);
+ const ck::index_t N1 = std::stoi(argv[4]);
+
+ const ck::index_t K0 = std::stoi(argv[5]);
+ const ck::index_t K1 = std::stoi(argv[6]);
+
+ a_ms_ks_lengths = {M0, M1, K0, K1};
+ a_ms_ks_strides = {
+ std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])};
+
+ b_ns_ks_lengths = {N0, N1, K0, K1};
+ b_ns_ks_strides = {
+ std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])};
+
+ e_ms_ns_lengths = {M0, M1, N0, N1};
+ e_ms_ns_strides = {
+ std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])};
+
+ scale = std::stof(argv[19]);
+ }
+ else
+ {
+ printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n");
+ printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
+ printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
+ printf("arg15 to 18: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
+ printf("arg19: scale\n");
+ exit(0);
+ }
+
+ auto f_tensor_space_size = [](auto lengths, auto strides) {
+ std::size_t space_size = 1;
+ for(std::size_t i = 0; i < lengths.size(); ++i)
+ {
+ space_size += (lengths[i] - 1) * strides[i];
+ }
+ return space_size;
+ };
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) *
+ f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) *
+ f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
+ SimpleDeviceMem e_device_buf(sizeof(EDataType) *
+ f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
+
+ using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
+ NumDimM,
+ NumDimN,
+ NumDimK,
+ ADataType,
+ BDataType,
+ ck::Tuple<>,
+ EDataType,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::Scale>;
+
+ // get device op instances
+ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+ DeviceOp>::GetInstances();
+
+ std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+ const auto a_element_op = AElementOp{};
+ const auto b_element_op = BElementOp{};
+ const auto cde_element_op = CDEElementOp{scale};
+
+ std::string best_op_name;
+ bool found = false;
+ int best_op_id = -1;
+ float best_ave_time = 0;
+ float best_tflops = 0;
+ float best_gb_per_sec = 0;
+
+ // profile device operation instances
+ std::cout << "Run all instances and do timing" << std::endl;
+
+ for(int i = 0; i < op_ptrs.size(); ++i)
+ {
+ auto& op_ptr = op_ptrs[i];
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ std::array<const void*, 0>{},
+ e_device_buf.GetDeviceBuffer(),
+ a_ms_ks_lengths,
+ a_ms_ks_strides,
+ b_ns_ks_lengths,
+ b_ns_ks_strides,
+ std::array<std::vector<ck::index_t>, 0>{},
+ std::array<std::vector<ck::index_t>, 0>{},
+ e_ms_ns_lengths,
+ e_ms_ns_strides,
+ a_element_op,
+ b_element_op,
+ cde_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ std::string op_name = op_ptr->GetTypeString();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+ ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
+ e_ms_ns_lengths.begin() + NumDimM,
+ ck::index_t{1},
+ std::multiplies{});
+
+ ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
+ e_ms_ns_lengths.begin() + NumDimM + NumDimN,
+ ck::index_t{1},
+ std::multiplies{});
+
+ ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
+ a_ms_ks_lengths.begin() + NumDimM + NumDimK,
+ ck::index_t{1},
+ std::multiplies{});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+ std::size_t num_btype =
+ sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+ float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+ << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+ if(tflops > best_tflops)
+ {
+ found = true;
+ best_op_id = i;
+ best_op_name = op_name;
+ best_tflops = tflops;
+ best_ave_time = ave_time;
+ best_gb_per_sec = gb_per_sec;
+ }
+ }
+ else
+ {
+ std::cout << op_name << " does not support this problem" << std::endl;
+ }
+ }
+
+ std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+ << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+ return 0;
+}
diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3e04a18599a7b488fb306cbaf598494bd48b69d5
--- /dev/null
+++ b/client_example/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 3.15)
+project(ck_app)
+add_compile_options(-std=c++17)
+
+find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
+find_package(hip REQUIRED PATHS /opt/rocm)
+message(STATUS "Build with HIP ${hip_VERSION}")
+
+add_subdirectory(01_gemm)
+add_subdirectory(02_gemm_add_add_fastgelu)
+add_subdirectory(03_gemm_layernorm)
+add_subdirectory(04_contraction)
diff --git a/client_example/README.md b/client_example/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..64a7130d537b1e2fb8752c4031e8430d11a6a46a
--- /dev/null
+++ b/client_example/README.md
@@ -0,0 +1,21 @@
+## Using CK as pre-built kernel library
+Client applications link to the installed CK library, so CK must be installed before building the client examples.
+
+
+## Build
+```bash
+mkdir -p client_example/build
+cd client_example/build
+```
+
+```bash
+cmake \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \
+..
+```
+
+### Build client example
+```bash
+ make -j
+```
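+
+### Run client example
+Each executable lands in its example's subdirectory of the build tree (the CMake default), e.g.:
+```bash
+./01_gemm/client_gemm
+```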
diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake
index 959bc4f4b0e26bb9c8e86a68eb34ed692041722c..3718b916ffe43996852507881db281dc5647fef0 100644
--- a/cmake/googletest.cmake
+++ b/cmake/googletest.cmake
@@ -8,7 +8,7 @@ endif()
message(STATUS "Fetching GoogleTest")
-list(APPEND GTEST_CMAKE_CXX_FLAGS
+list(APPEND GTEST_CMAKE_CXX_FLAGS
-Wno-undef
-Wno-reserved-identifier
-Wno-global-constructors
@@ -31,7 +31,11 @@ FetchContent_Declare(
# Will be necessary for windows build
# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
-FetchContent_MakeAvailable(googletest)
+FetchContent_GetProperties(googletest)
+if(NOT googletest_POPULATED)
+ FetchContent_Populate(googletest)
+ add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)
+endif()
target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp
index 9a22628777c06806990bb9a9972e8d773a7a92f5..0a3060fdc71b22cd655634c7b5d01b00363dffee 100644
--- a/example/01_gemm/gemm_dl_fp16.cpp
+++ b/example/01_gemm/gemm_dl_fp16.cpp
@@ -1,20 +1,21 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-
-#include "check_err.hpp"
-#include "config.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "device_tensor.hpp"
-#include "device_gemm_dl.hpp"
-#include "element_wise_operation.hpp"
-#include "reference_gemm.hpp"
-#include "gemm_specialization.hpp"
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp
index 32b183a3a160e5ffd05bfda859a8bfaea01bdfd5..d9677da9b9fd6aa2578cb20b3176e5c5d45b0ffd 100644
--- a/example/01_gemm/gemm_dl_fp32.cpp
+++ b/example/01_gemm/gemm_dl_fp32.cpp
@@ -1,20 +1,21 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-
-#include "check_err.hpp"
-#include "config.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "device_tensor.hpp"
-#include "device_gemm_dl.hpp"
-#include "element_wise_operation.hpp"
-#include "reference_gemm.hpp"
-#include "gemm_specialization.hpp"
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp
index 16c9213104a8f99572dac365c622862ccd10a57f..65206d602f66eb800c783bace5a784fadee0c86a 100644
--- a/example/01_gemm/gemm_dl_int8.cpp
+++ b/example/01_gemm/gemm_dl_int8.cpp
@@ -1,20 +1,21 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-
-#include "check_err.hpp"
-#include "config.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "device_tensor.hpp"
-#include "device_gemm_dl.hpp"
-#include "element_wise_operation.hpp"
-#include "reference_gemm.hpp"
-#include "gemm_specialization.hpp"
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp
index b126736be656e4c0f90136cd0badc65ae5c491de..0575c0bd9e2fa89a5f8823d7a7796d3d75a50ffd 100644
--- a/example/01_gemm/gemm_xdl_bf16.cpp
+++ b/example/01_gemm/gemm_xdl_bf16.cpp
@@ -1,20 +1,21 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-
-#include "check_err.hpp"
-#include "config.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "device_tensor.hpp"
-#include "device_gemm_xdl_cshuffle.hpp"
-#include "element_wise_operation.hpp"
-#include "reference_gemm.hpp"
-#include "gemm_specialization.hpp"
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -83,8 +84,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
-using ReferenceGemmInstance = ck::tensor_operation::host::
- ReferenceGemm<float, float, float, float, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+ BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
@@ -215,24 +221,17 @@ int main(int argc, char* argv[])
if(do_verification)
{
- Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
- Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
- Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
- Tensor<float> c_m_n_device_f32_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
- bf16_to_f32_(a_m_k, a_f32_m_k);
- bf16_to_f32_(b_k_n, b_f32_k_n);
- bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
+ Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
- a_f32_m_k, b_f32_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+ a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
- return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1;
+ return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
}
return 0;
diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp
index 003534f79aa536f8f7aa374baab12c9a2668c06f..0d194403773b1564ba179d924b267b7e91d0e4e9 100644
--- a/example/01_gemm/gemm_xdl_fp16.cpp
+++ b/example/01_gemm/gemm_xdl_fp16.cpp
@@ -1,20 +1,22 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-#include "check_err.hpp"
-#include "config.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_cshuffle.hpp"
-#include "element_wise_operation.hpp"
-#include "reference_gemm.hpp"
-#include "gemm_specialization.hpp"
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -27,30 +29,42 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ADataType = ck::half_t;
-using BDataType = ck::half_t;
-using CDataType = ck::half_t;
-using AccDataType = float;
+using ADataType = F16;
+using BDataType = F16;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using CDataType = F16;
-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
+using ALayout = Row;
+using BLayout = Col;
+using CLayout = Row;
-using AElementOp = ck::tensor_operation::element_wise::PassThrough;
-using BElementOp = ck::tensor_operation::element_wise::PassThrough;
-using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-// clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
- < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
+ // clang-format off
+//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+// clang-format on
+
+using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
+ // clang-format off
+//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
+using DeviceGemmInstance = DeviceGemmInstance0;
+
using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
@@ -69,7 +83,11 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
- if(argc == 4)
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
@@ -93,7 +111,7 @@ int main(int argc, char* argv[])
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
- printf("arg3: time kernel (0=n0, 1=yes)\n");
+ printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
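
The updated example keeps the device-op calling convention unchanged: build an argument, ask the instance whether it supports the problem, then run through an invoker. A minimal sketch of that protocol against a mock op (`MockGemm` and its trivial bodies are illustrative stand-ins, not CK types):

```cpp
#include <iostream>
#include <stdexcept>

// Mock device op mirroring the CK host-side protocol:
// MakeArgument -> IsSupportedArgument -> MakeInvoker -> Invoker::Run.
struct MockGemm
{
    struct Argument
    {
        int M, N, K;
    };
    struct Invoker
    {
        // Real invokers launch the kernel and return the average time in ms.
        float Run(const Argument&) const { return 0.0f; }
    };
    Argument MakeArgument(int M, int N, int K) const { return {M, N, K}; }
    // Real instances reject shapes their tile/vector parameters cannot cover.
    bool IsSupportedArgument(const Argument& arg) const
    {
        return arg.M > 0 && arg.N > 0 && arg.K > 0;
    }
    Invoker MakeInvoker() const { return {}; }
};

int main()
{
    auto gemm     = MockGemm{};
    auto argument = gemm.MakeArgument(3840, 4096, 4096);

    if(!gemm.IsSupportedArgument(argument))
    {
        throw std::runtime_error("wrong! unsupported GEMM problem");
    }

    auto invoker = gemm.MakeInvoker();
    std::cout << "ave_time: " << invoker.Run(argument) << " ms" << std::endl;
}
```
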
diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp
index 7cea68c8b0f11858b697c3cacb38f473632f0c61..1b222c971267102dbd3cbd7465aaf82009d6ecd9 100644
--- a/example/01_gemm/gemm_xdl_fp64.cpp
+++ b/example/01_gemm/gemm_xdl_fp64.cpp
@@ -1,21 +1,22 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-
-#include "check_err.hpp"
-#include "config.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_cshuffle.hpp"
-#include "element_wise_operation.hpp"
-#include "reference_gemm.hpp"
-#include "gemm_specialization.hpp"
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
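
These examples share an `f_host_tensor_descriptor` helper (visible in the files below) that maps (rows, cols, stride, layout) to 2-d strides: `{stride, 1}` for row-major, `{1, stride}` for column-major. A standalone sketch of that mapping, with `RowMajor`/`ColumnMajor` as stand-ins for the `ck::tensor_layout::gemm` tags:

```cpp
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>

struct RowMajor {};
struct ColumnMajor {};

// Row-major: consecutive columns sit next to each other -> strides {stride, 1}.
// Column-major: consecutive rows sit next to each other  -> strides {1, stride}.
template <typename Layout>
std::pair<std::size_t, std::size_t> make_strides(std::size_t stride)
{
    return std::is_same<Layout, RowMajor>::value
               ? std::make_pair(stride, std::size_t{1})
               : std::make_pair(std::size_t{1}, stride);
}

int main()
{
    const auto a = make_strides<RowMajor>(4096);    // A is M x K, row-major
    const auto b = make_strides<ColumnMajor>(4096); // B is K x N, column-major
    std::cout << "A strides: {" << a.first << ", " << a.second << "}\n"
              << "B strides: {" << b.first << ", " << b.second << "}\n";
}
```
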
diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp
index 27fcd62a2c13b1031ee1fdccd5fafe423ae8227e..4ed1f177db6d0e5df668256f232d631ca9f2464a 100644
--- a/example/01_gemm/gemm_xdl_int8.cpp
+++ b/example/01_gemm/gemm_xdl_int8.cpp
@@ -1,20 +1,22 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-
-#include "check_err.hpp"
-#include "config.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "device_tensor.hpp"
-#include "device_gemm_xdl_cshuffle.hpp"
-#include "element_wise_operation.hpp"
-#include "reference_gemm.hpp"
-#include "gemm_specialization.hpp"
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
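
Each example opens with `template <ck::index_t... Is> using S = ck::Sequence<Is...>;`, so tuning parameters such as `S<4, 64, 1>` are compile-time integer lists. A toy stand-in for what the alias carries (`Sequence` below is illustrative, not the CK class):

```cpp
#include <cstddef>
#include <iostream>

using index_t = int; // stand-in for ck::index_t

// Toy compile-time integer list in the spirit of ck::Sequence.
template <index_t... Is>
struct Sequence
{
    static constexpr std::size_t Size() { return sizeof...(Is); }
};

template <index_t... Is>
using S = Sequence<Is...>;

// In the GEMM examples, S<4, 64, 1> describes how 256 threads are arranged
// over the K0 x M x K1 dimensions of a block transfer (4 * 64 * 1 = 256).
using ABlockTransferThreadClusterLengths = S<4, 64, 1>;

int main()
{
    std::cout << ABlockTransferThreadClusterLengths::Size() << " dimensions\n";
}
```
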
diff --git a/example/02_gemm_alpha_beta/CMakeLists.txt b/example/02_gemm_alpha_beta/CMakeLists.txt
deleted file mode 100644
index 1b81cf21622b6e70cb43dbd4bc90874fc7bf5580..0000000000000000000000000000000000000000
--- a/example/02_gemm_alpha_beta/CMakeLists.txt
+++ /dev/null
@@ -1 +0,0 @@
-add_example_executable(example_gemm_xdl_alpha_beta gemm_xdl_alpha_beta.cpp)
diff --git a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp
deleted file mode 100644
index 1a6e1de4dcfb4f75afca02b204e3963dab86b9e7..0000000000000000000000000000000000000000
--- a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp
+++ /dev/null
@@ -1,253 +0,0 @@
-#include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-
-#include "check_err.hpp"
-#include "config.hpp"
-#include "print.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "host_gemm.hpp"
-#include "device_tensor.hpp"
-#include "device_base.hpp"
-#include "device_gemm_xdl_c_shuffle_bias_2d.hpp"
-#include "element_wise_operation.hpp"
-#include "reference_gemm_bias_2d.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using ADataType = ck::half_t;
-using BDataType = ck::half_t;
-using CDataType = ck::half_t;
-using AccDataType = float;
-
-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
-
-using AElementOp = ck::tensor_operation::element_wise::PassThrough;
-using BElementOp = ck::tensor_operation::element_wise::PassThrough;
-using CElementOp = ck::tensor_operation::element_wise::AlphaBetaAdd;
-
-// clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_2d<
- ADataType, // ADataType
- BDataType, // BDataType
- CDataType, // CDataType
- AccDataType, // AccDataType
- ALayout, // ALayout
- BLayout, // BLayout
- CLayout, // CLayout
- AElementOp, // AElementwiseOperation
- BElementOp, // BElementwiseOperation
- CElementOp, // CElementwiseOperation
- 256, // BlockSize
- 256, // MPerBlock
- 128, // NPerBlock
- 4, // K0PerBlock
- 8, // K1
- 32, // MPerXDL
- 32, // NPerXDL
- 4, // MXdlPerWave
- 2, // NXdlPerWave
- S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
- S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
- S<1, 0, 2>, // ABlockTransferSrcAccessOrder
- 2, // ABlockTransferSrcVectorDim
- 8, // ABlockTransferSrcScalarPerVector
- 8, // ABlockTransferDstScalarPerVector_K1
- true, // ABlockLdsAddExtraM
- S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
- S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
- S<1, 0, 2>, // BBlockTransferSrcAccessOrder
- 2, // BBlockTransferSrcVectorDim
- 8, // BBlockTransferSrcScalarPerVector
- 8, // BBlockTransferDstScalarPerVector_K1
- true, // BBlockLdsAddExtraN
- 1, // CShuffleMXdlPerWavePerShuffle
- 1, // CShuffleNXdlPerWavePerShuffle
- S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
- 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
-// clang-format on
-
-using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
-
-int main(int argc, char* argv[])
-{
- bool do_verification = true;
- int init_method = 1;
- bool time_kernel = false;
-
- // GEMM shape
- ck::index_t M = 3840;
- ck::index_t N = 4096;
- ck::index_t K = 4096;
-
- ck::index_t StrideA = 4096;
- ck::index_t StrideB = 4096;
- ck::index_t StrideC = 4096;
-
- float alpha = 1.0f;
- float beta = 1.0f;
-
- if(argc == 4)
- {
- do_verification = std::stoi(argv[1]);
- init_method = std::stoi(argv[2]);
- time_kernel = std::stoi(argv[3]);
- }
- else if(argc == 6)
- {
- do_verification = std::stoi(argv[1]);
- init_method = std::stoi(argv[2]);
- time_kernel = std::stoi(argv[3]);
-
- alpha = std::stof(argv[4]);
- beta = std::stof(argv[5]);
- }
- else if(argc == 12)
- {
- do_verification = std::stoi(argv[1]);
- init_method = std::stoi(argv[2]);
- time_kernel = std::stoi(argv[3]);
-
- M = std::stoi(argv[4]);
- N = std::stoi(argv[5]);
- K = std::stoi(argv[6]);
-
- StrideA = std::stoi(argv[7]);
- StrideB = std::stoi(argv[8]);
- StrideC = std::stoi(argv[9]);
-
- alpha = std::stof(argv[10]);
- beta = std::stof(argv[11]);
- }
- else
- {
- printf("arg1: verification (0=no, 1=yes)\n");
- printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
- printf("arg3: time kernel (0=n0, 1=yes)\n");
- printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n");
- exit(0);
- }
-
- auto f_host_tensor_descriptor =
- [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
- if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
- {
- return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
- std::vector<std::size_t>({stride, 1}));
- }
- else
- {
- return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
- std::vector<std::size_t>({1, stride}));
- }
- };
-
- Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
- Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
- Tensor<CDataType> c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
- Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
- Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
- std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
- std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
- std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl;
- std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
-
- switch(init_method)
- {
- case 0: break;
- case 1:
- a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
- b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
- c0_m_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
- break;
- default:
- a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
- b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
- c0_m_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{-0.5, 0.5});
- }
-
- DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
- DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
- DeviceMem c0_m_n_device_buf(sizeof(CDataType) * c0_m_n.mDesc.GetElementSpace());
- DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
-
- a_m_k_device_buf.ToDevice(a_m_k.mData.data());
- b_k_n_device_buf.ToDevice(b_k_n.mData.data());
- c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
- c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
-
- // do GEMM
- auto gemm = DeviceGemmInstance{};
- auto invoker = gemm.MakeInvoker();
- auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
- static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
- static_cast<CDataType*>(c0_m_n_device_buf.GetDeviceBuffer()),
- static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
- M,
- N,
- K,
- StrideA,
- StrideB,
- StrideC,
- AElementOp{},
- BElementOp{},
- CElementOp{alpha, beta});
-
- if(!gemm.IsSupportedArgument(argument))
- {
- throw std::runtime_error(
- "wrong! device_gemm with the specified compilation parameters does "
- "not support this GEMM problem");
- }
-
- float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-
- std::size_t flop = std::size_t(2) * M * N * K;
- std::size_t num_btype =
- sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
-
- float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
-
- float gb_per_sec = num_btype / 1.E6 / ave_time;
-
- std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
- << std::endl;
-
- c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
-
- if(do_verification)
- {
- auto ref_gemm = ReferenceGemmInstance{};
- auto ref_invoker = ref_gemm.MakeInvoker();
-
- auto ref_argument = ref_gemm.MakeArgument(a_m_k,
- b_k_n,
- c0_m_n,
- c_m_n_host_result,
- AElementOp{},
- BElementOp{},
- CElementOp{alpha, beta});
-
- ref_invoker.Run(ref_argument);
-
- return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
- }
-
- return 0;
-}
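
The deleted alpha/beta example is superseded by the bilinear example added next: rather than a dedicated bias kernel, the epilogue becomes an elementwise op over the GEMM result C and an extra input tensor D. The `AlphaBetaAdd` functor below implements, per element (a sketch of the math, with C the accumulated product):

```latex
% Bilinear GEMM epilogue as computed by AlphaBetaAdd:
\[
E_{mn} = \alpha \, C_{mn} + \beta \, D_{mn},
\qquad
C_{mn} = \sum_{k} A_{mk} \, B_{kn}
\]
```
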
diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..10ec0f1a71151668e262efcdbaff7100d2d08dfa
--- /dev/null
+++ b/example/02_gemm_bilinear/CMakeLists.txt
@@ -0,0 +1 @@
+add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
diff --git a/example/02_gemm_alpha_beta/README.md b/example/02_gemm_bilinear/README.md
similarity index 69%
rename from example/02_gemm_alpha_beta/README.md
rename to example/02_gemm_bilinear/README.md
index ba2a3068f3e78757d34f3e9d7f382a76aef19bc5..9eb87e1e3479d72497ec72956b1de649b0ff735f 100644
--- a/example/02_gemm_alpha_beta/README.md
+++ b/example/02_gemm_bilinear/README.md
@@ -1,11 +1,13 @@
-# Instructions for ```example_gemm_xdl_alpha_beta```
+# Instructions for ```example_gemm_bilinear_xdl_fp16```
-## Run ```example_gemm_xdl_alpha_beta```
+## Run ```example_gemm_bilinear_xdl_fp16```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
-#arg3: run kernel # of times (>1)
-./bin/example_gemm_xdl_alpha_beta 1 1 1 0.5 0.5
+#arg3: time kernel (0=no, 1=yes)
+#arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE
+#arg11 to 12: alpha, beta
+./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5
```
Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16)
```
diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9b340807ba6f783b48ab860c6776799d14311649
--- /dev/null
+++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
@@ -0,0 +1,305 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+struct AlphaBetaAdd
+{
+ AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
+
+ template <typename E, typename C, typename D>
+ __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
+
+ template <>
+ __host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
+ ck::half_t& e, const float& c, const ck::half_t& d) const
+ {
+ e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
+ };
+
+ float alpha_;
+ float beta_;
+};
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using ADataType = F16;
+using BDataType = F16;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using DDataType = F16;
+using DsDataType = ck::Tuple<DDataType>;
+using EDataType = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using DELayout = Row;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = AlphaBetaAdd;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+using DeviceOpInstance =
+ ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
+ BLayout,
+ DELayout,
+ ADataType,
+ BDataType,
+ AccDataType,
+ CShuffleDataType,
+ DsDataType,
+ EDataType,
+ AElementOp,
+ BElementOp,
+ CDEElementOp,
+ GemmDefault,
+ 1,
+ 256,
+ 256,
+ 128,
+ 32,
+ 8,
+ 8,
+ 32,
+ 32,
+ 4,
+ 2,
+ S<4, 64, 1>,
+ S<1, 0, 2>,
+ S<1, 0, 2>,
+ 2,
+ 8,
+ 8,
+ 1,
+ S<4, 64, 1>,
+ S<1, 0, 2>,
+ S<1, 0, 2>,
+ 2,
+ 8,
+ 8,
+ 1,
+ 1,
+ 1,
+ S<1, 32, 1, 8>,
+ 8>;
+
+int main(int argc, char* argv[])
+{
+ bool do_verification = true;
+ int init_method = 1;
+ bool time_kernel = false;
+
+ // GEMM shape
+ ck::index_t M = 3840;
+ ck::index_t N = 4096;
+ ck::index_t K = 4096;
+
+ ck::index_t StrideA = 4096;
+ ck::index_t StrideB = 4096;
+ ck::index_t StrideD = 4096;
+ ck::index_t StrideE = 4096;
+
+ float alpha = 1.0f;
+ float beta = 1.0f;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 4)
+ {
+ do_verification = std::stoi(argv[1]);
+ init_method = std::stoi(argv[2]);
+ time_kernel = std::stoi(argv[3]);
+ }
+ else if(argc == 6)
+ {
+ do_verification = std::stoi(argv[1]);
+ init_method = std::stoi(argv[2]);
+ time_kernel = std::stoi(argv[3]);
+
+ alpha = std::stof(argv[4]);
+ beta = std::stof(argv[5]);
+ }
+ else if(argc == 13)
+ {
+ do_verification = std::stoi(argv[1]);
+ init_method = std::stoi(argv[2]);
+ time_kernel = std::stoi(argv[3]);
+
+ M = std::stoi(argv[4]);
+ N = std::stoi(argv[5]);
+ K = std::stoi(argv[6]);
+
+ StrideA = std::stoi(argv[7]);
+ StrideB = std::stoi(argv[8]);
+ StrideD = std::stoi(argv[9]);
+ StrideE = std::stoi(argv[10]);
+
+ alpha = std::stof(argv[11]);
+ beta = std::stof(argv[12]);
+ }
+ else
+ {
+ printf("arg1: verification (0=no, 1=yes)\n");
+ printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+ printf("arg3: time kernel (0=no, 1=yes)\n");
+ printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
+ "beta\n");
+ exit(0);
+ }
+
+ auto f_host_tensor_descriptor =
+ [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+ if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+ {
+ return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+ std::vector<std::size_t>({stride, 1}));
+ }
+ else
+ {
+ return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+ std::vector<std::size_t>({1, stride}));
+ }
+ };
+
+ Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+ Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+ Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DELayout{}));
+ Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
+ Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
+
+ std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+ std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+ std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
+ std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+
+ switch(init_method)
+ {
+ case 0: break;
+ case 1:
+ a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+ b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+ d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
+ break;
+ default:
+ a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+ b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+ d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
+ }
+
+ DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+ DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+ DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
+ DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
+
+ a_device_buf.ToDevice(a_m_k.mData.data());
+ b_device_buf.ToDevice(b_k_n.mData.data());
+ d_device_buf.ToDevice(d_m_n.mData.data());
+ e_device_buf.ToDevice(e_m_n_device_result.mData.data());
+
+ auto a_element_op = AElementOp{};
+ auto b_element_op = BElementOp{};
+ auto cde_element_op = CDEElementOp{alpha, beta};
+
+ // do GEMM
+ auto device_op = DeviceOpInstance{};
+ auto invoker = device_op.MakeInvoker();
+ auto argument =
+ device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+ e_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ std::array<ck::index_t, 1>{StrideD},
+ StrideE,
+ a_element_op,
+ b_element_op,
+ cde_element_op);
+
+ if(!device_op.IsSupportedArgument(argument))
+ {
+ throw std::runtime_error(
+ "wrong! device_gemm with the specified compilation parameters does "
+ "not support this GEMM problem");
+ }
+
+ float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+ std::size_t num_btype =
+ sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+ float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+ << std::endl;
+
+ e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+ if(do_verification)
+ {
+ Tensor<CShuffleDataType> c_m_n(HostTensorDescriptor(
+ std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
+
+ using ReferenceGemmInstance = ck::tensor_operation::host::
+ ReferenceGemm<ADataType, BDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, PassThrough>;
+ auto ref_gemm = ReferenceGemmInstance{};
+ auto ref_invoker = ref_gemm.MakeInvoker();
+
+ auto ref_argument =
+ ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
+
+ ref_invoker.Run(ref_argument);
+
+ for(int m = 0; m < M; ++m)
+ {
+ for(int n = 0; n < N; ++n)
+ {
+ cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
+ }
+ }
+
+ e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+ return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
+ }
+
+ return 0;
+}
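
Note how verification works here: the host reference computes only the plain GEMM into c_m_n, and the very same cde_element_op is then applied pointwise on the host, so the epilogue logic is shared between the device and reference paths. A self-contained sketch of that two-stage check on a toy 1x1x2 problem (plain C++, float stand-ins for F16/F32):

```cpp
#include <cassert>
#include <cmath>

// Stand-in for the example's AlphaBetaAdd epilogue: e = alpha * c + beta * d.
struct AlphaBetaAdd
{
    float alpha_;
    float beta_;
    void operator()(float& e, const float& c, const float& d) const
    {
        e = alpha_ * c + beta_ * d;
    }
};

int main()
{
    // Stage 1: reference "GEMM" for a 1 x 1 x 2 problem: c = sum_k a[k] * b[k].
    const float a[2] = {2.0f, 3.0f};
    const float b[2] = {4.0f, 5.0f};
    const float c    = a[0] * b[0] + a[1] * b[1]; // 23

    // Stage 2: apply the epilogue on the host, exactly as the example's loop
    // does with cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)).
    const float d = 1.0f;
    float e_host  = 0.0f;
    AlphaBetaAdd{0.5f, 2.0f}(e_host, c, d);

    assert(std::fabs(e_host - (0.5f * 23.0f + 2.0f * 1.0f)) < 1e-6f);
    return 0;
}
```
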
diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt
index d07ad6e36c3a9f1deda141a66e20945c7fff37c1..35c54abac03094f24187df2503aa02b6812c20f3 100644
--- a/example/03_gemm_bias_relu/CMakeLists.txt
+++ b/example/03_gemm_bias_relu/CMakeLists.txt
@@ -1 +1 @@
-add_example_executable(example_gemm_xdl_bias_relu gemm_xdl_bias_relu.cpp)
+add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)
diff --git a/example/03_gemm_bias_relu/README.md b/example/03_gemm_bias_relu/README.md
index f8d9bd6152907de226567aefc85b91de00238e05..f28a9a071c879e92be34f84054661647c31ebb35 100644
--- a/example/03_gemm_bias_relu/README.md
+++ b/example/03_gemm_bias_relu/README.md
@@ -1,28 +1,10 @@
-# Instructions for ```example_gemm_xdl_bias_relu_add```
+# Instructions for ```example_gemm_bias_relu_xdl_fp16```
-## Run ```example_gemm_xdl_bias_relu_add```
+## Run ```example_gemm_bias_relu_xdl_fp16```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
-#arg3: run kernel # of times (>1)
-#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
-./bin/example_gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096
-```
-
-Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
-```
-a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
-b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
-c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0}
-arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
-arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
-arg.c_grid_desc_m_n_{ 3840, 4096}
-arg.c0_grid_desc_m_n_{ 3840, 4096}
-arg.c1_grid_desc_m_n_{ 3840, 4096}
-launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 5 times...
-Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s
+#arg3: time kernel (0=no, 1=yes)
+#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE
+./bin/example_gemm_bias_relu_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096
```
diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e36280f42db89fe8d24d767365cc4fd40674af4a
--- /dev/null
+++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp
@@ -0,0 +1,281 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/host_tensor/device_memory.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// C = A * B
+// E = Relu(C + D);
+struct AddRelu
+{
+ __host__ __device__ void
+ operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
+ {
+ const ck::half_t x = c + d;
+
+ e = x > 0 ? x : 0;
+ }
+};
+
+using ADataType = F16;
+using BDataType = F16;
+using AccDataType = F32;
+using CShuffleDataType = F16;
+using DDataType = F16;
+using DsDataType = ck::Tuple<DDataType>;
+using EDataType = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using ELayout = Row;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = AddRelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+using DeviceOpInstance =
+ ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
+ BLayout,
+ ELayout,
+ ADataType,
+ BDataType,
+ AccDataType,
+ CShuffleDataType,
+ DsDataType,
+ EDataType,
+ AElementOp,
+ BElementOp,
+ CDEElementOp,
+ GemmDefault,
+ 1,
+ 256,
+ 256,
+ 128,
+ 32,
+ 8,
+ 8,
+ 32,
+ 32,
+ 4,
+ 2,
+ S<4, 64, 1>,
+ S<1, 0, 2>,
+ S<1, 0, 2>,
+ 2,
+ 8,
+ 8,
+ 1,
+ S<4, 64, 1>,
+ S<1, 0, 2>,
+ S<1, 0, 2>,
+ 2,
+ 8,
+ 8,
+ 1,
+ 1,
+ 1,
+ S<1, 32, 1, 8>,
+ 8>;
+
+int main(int argc, char* argv[])
+{
+ bool do_verification = true;
+ int init_method = 1;
+ bool time_kernel = false;
+
+ // GEMM shape
+ ck::index_t M = 3840;
+ ck::index_t N = 4096;
+ ck::index_t K = 4096;
+
+ ck::index_t StrideA = 4096;
+ ck::index_t StrideB = 4096;
+ ck::index_t StrideE = 4096;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 4)
+ {
+ do_verification = std::stoi(argv[1]);
+ init_method = std::stoi(argv[2]);
+ time_kernel = std::stoi(argv[3]);
+ }
+ else if(argc == 10)
+ {
+ do_verification = std::stoi(argv[1]);
+ init_method = std::stoi(argv[2]);
+ time_kernel = std::stoi(argv[3]);
+
+ M = std::stoi(argv[4]);
+ N = std::stoi(argv[5]);
+ K = std::stoi(argv[6]);
+
+ StrideA = std::stoi(argv[7]);
+ StrideB = std::stoi(argv[8]);
+ StrideE = std::stoi(argv[9]);
+ }
+ else
+ {
+ printf("arg1: verification (0=no, 1=yes)\n");
+ printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+ printf("arg3: time kernel (0=no, 1=yes)\n");
+ printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
+ exit(0);
+ }
+
+ auto f_host_tensor_descriptor =
+ [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+ if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+ {
+ return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+ std::vector<std::size_t>({stride, 1}));
+ }
+ else
+ {
+ return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+ std::vector<std::size_t>({1, stride}));
+ }
+ };
+
+ Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+ Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+ Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{}));
+ Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+ Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+ std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+ std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+ std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
+ std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+
+ switch(init_method)
+ {
+ case 0: break;
+ case 1:
+ a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+ b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+ d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
+ break;
+ default:
+ a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+ b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+ d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
+ }
+
+ DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+ DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+ DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
+ DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
+
+ a_device_buf.ToDevice(a_m_k.mData.data());
+ b_device_buf.ToDevice(b_k_n.mData.data());
+ d_device_buf.ToDevice(d_m_n.mData.data());
+
+ auto a_element_op = AElementOp{};
+ auto b_element_op = BElementOp{};
+ auto cde_element_op = CDEElementOp{};
+
+ // do GEMM
+ auto device_op = DeviceOpInstance{};
+
+ auto invoker = device_op.MakeInvoker();
+
+ auto argument =
+ device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+ e_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ std::array<ck::index_t, 1>{0},
+ StrideE,
+ a_element_op,
+ b_element_op,
+ cde_element_op);
+
+ if(!device_op.IsSupportedArgument(argument))
+ {
+ throw std::runtime_error("wrong! this device_op instance does not support this problem");
+ }
+
+ float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+
+ std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+ sizeof(EDataType) * M * N + sizeof(EDataType) * N;
+
+ float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+ << std::endl;
+
+ if(do_verification)
+ {
+ e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+ Tensor<CShuffleDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+ using ReferenceGemmInstance = ck::tensor_operation::host::
+ ReferenceGemm<ADataType, BDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, PassThrough>;
+
+ auto ref_gemm = ReferenceGemmInstance{};
+ auto ref_invoker = ref_gemm.MakeInvoker();
+
+ auto ref_argument =
+ ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
+
+ ref_invoker.Run(ref_argument);
+
+ for(int m = 0; m < M; ++m)
+ {
+ for(int n = 0; n < N; ++n)
+ {
+ cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
+ }
+ }
+
+ return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
+ }
+
+ return 0;
+}
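
One detail worth calling out in this example: the bias D is declared as an M x N tensor but with stride 0 on the row dimension (f_host_tensor_descriptor(M, N, 0, ELayout{}) on the host, a zero StrideD on the device side), so every row m reads the same length-N bias vector. A short sketch of why a zero stride acts as a broadcast (offset arithmetic only):

```cpp
#include <cassert>
#include <cstddef>

// Linear offset of element (m, n) in a 2-d tensor with the given strides.
std::size_t offset(std::size_t m, std::size_t n, std::size_t stride_m, std::size_t stride_n)
{
    return m * stride_m + n * stride_n;
}

int main()
{
    // Output E, row-major with StrideE = 4096: rows start 4096 elements apart.
    assert(offset(2, 7, 4096, 1) == 2 * 4096 + 7);

    // Bias D with row stride 0: every row aliases row 0, so only N elements
    // ever need to be stored or transferred, yet indexing stays (m, n).
    assert(offset(0, 7, 0, 1) == offset(2, 7, 0, 1));

    return 0;
}
```
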
diff --git a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp
deleted file mode 100644
index 3bf3003c147c7107aaa8cb2bda0eed7b1043ee5a..0000000000000000000000000000000000000000
--- a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp
+++ /dev/null
@@ -1,239 +0,0 @@
-#include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-#include <stdlib.h>
-#include <half.hpp>
-
-#include "check_err.hpp"
-#include "config.hpp"
-#include "print.hpp"
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "host_tensor_generator.hpp"
-#include "host_gemm.hpp"
-#include "device_tensor.hpp"
-#include "element_wise_operation.hpp"
-#include "device_gemm_xdl_c_shuffle_bias_activation.hpp"
-#include "reference_gemm_bias_activation.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using ADataType = ck::half_t;
-using BDataType = ck::half_t;
-using CDataType = ck::half_t;
-using AccDataType = float;
-
-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
-
-using AElementOp = ck::tensor_operation::element_wise::PassThrough;
-using BElementOp = ck::tensor_operation::element_wise::PassThrough;
-using CElementOp = ck::tensor_operation::element_wise::AddRelu;
-
-// clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation<
- ADataType, // ADataType
- BDataType, // BDataType
- CDataType, // CDataType
- AccDataType, // AccDataType
- ALayout, // ALayout
- BLayout, // BLayout
- CLayout, // CLayout
- AElementOp, // AElementwiseOperation
- BElementOp, // BElementwiseOperation
- CElementOp, // CElementwiseOperation
- 256, // BlockSize
- 256, // MPerBlock
- 128, // NPerBlock
- 4, // K0PerBlock
- 8, // K1
- 32, // MPerXDL
- 32, // NPerXDL
- 4, // MXdlPerWave
- 2, // NXdlPerWave
- S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
- S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
- S<1, 0, 2>, // ABlockTransferSrcAccessOrder
- 2, // ABlockTransferSrcVectorDim
- 8, // ABlockTransferSrcScalarPerVector
- 8, // ABlockTransferDstScalarPerVector_K1
- true, // ABlockLdsAddExtraM
- S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
- S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
- S<1, 0, 2>, // BBlockTransferSrcAccessOrder
- 2, // BBlockTransferSrcVectorDim
- 8, // BBlockTransferSrcScalarPerVector
- 8, // BBlockTransferDstScalarPerVector_K1
- true, // BBlockLdsAddExtraN
- 1, // CShuffleMXdlPerWavePerShuffle
- 1, // CShuffleNXdlPerWavePerShuffle
- S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
- 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
-// clang-format on
-
-using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActivation<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
-
-int main(int argc, char* argv[])
-{
- bool do_verification = true;
- int init_method = 1;
- bool time_kernel = false;
-
- // GEMM shape
- ck::index_t M = 3840;
- ck::index_t N = 4096;
- ck::index_t K = 4096;
-
- ck::index_t StrideA = 4096;
- ck::index_t StrideB = 4096;
- ck::index_t StrideC = 4096;
-
- if(argc == 4)
- {
- do_verification = std::stoi(argv[1]);
- init_method = std::stoi(argv[2]);
- time_kernel = std::stoi(argv[3]);
- }
- else if(argc == 10)
- {
- do_verification = std::stoi(argv[1]);
- init_method = std::stoi(argv[2]);
- time_kernel = std::stoi(argv[3]);
-
- M = std::stoi(argv[4]);
- N = std::stoi(argv[5]);
- K = std::stoi(argv[6]);
-
- StrideA = std::stoi(argv[7]);
- StrideB = std::stoi(argv[8]);
- StrideC = std::stoi(argv[9]);
- }
- else
- {
- printf("arg1: verification (0=no, 1=yes)\n");
- printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
- printf("arg3: time kernel (0=n0, 1=yes)\n");
- printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
- exit(0);
- }
-
- auto f_host_tensor_descriptor =
- [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
- if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
- {
- return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
- std::vector<std::size_t>({stride, 1}));
- }
- else
- {
- return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
- std::vector<std::size_t>({1, stride}));
- }
- };
-
- Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
- Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
- Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
- Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
- // c0_n[n]
- Tensor<CDataType> c0_n(HostTensorDescriptor(
- std::vector<std::size_t>({static_cast<std::size_t>(N)}), std::vector<std::size_t>({1})));
-
- std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
- std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
- std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
- std::cout << "c0_n: " << c0_n.mDesc << std::endl;
-
- switch(init_method)
- {
- case 0: break;
- case 1:
- a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
- b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
- c0_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
- break;
- default:
- a_m_k.GenerateTensorValue(GeneratorTensor_3