"...composable_kernel.git" did not exist on "570ff3ddbe52d6e1d5e89284d8f3456c0ba34c23"
Commit 49facb91 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

Add files for GEMV and tall-and-skinny GEMM examples and corresponding entries to ckprofiler

parent 98fd41f5
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemv_splitk)
add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
add_dependencies(example_gemv_splitk
example_gemv_splitk_fp16)
set(target 1)
endif()
endforeach()
# Instructions for ```example_gemv_splitk```
## Run ```example_gemv_splitk```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
#arg4: number of split-K batches
bin/example_gemv_splitk_fp16 1 2 1 231
```
Result (MI250 @ 800 MHz, 181.05 TFlops peak FP16)
```
a_m_k: dim 2, lengths {1, 4608}, strides {4608, 1}
b_k_n: dim 2, lengths {4608, 1104}, strides {1104, 1}
c_m_n: dim 2, lengths {1, 1104}, strides {1104, 1}
Perf: 0.0111038 ms, 0.916305 TFlops, 917.334 GB/s, deviceTsmmDl<64, 1, 128, 3, 4, 1, 2, 1>
```
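The Perf figures follow directly from the formulas in the example source (`flop = 2 * M * N * K`, bytes counted over the fp16 A, B and C tensors). A minimal standalone check of the numbers reported above:
```cpp
// Sanity check of the reported TFlops / GB/s, using the same formulas as the example source.
#include <cstdio>

int main()
{
    const double M = 1, N = 1104, K = 4608; // default GEMV problem size
    const double ave_time_ms = 0.0111038;   // kernel time reported above
    const double flop  = 2.0 * M * N * K;
    const double bytes = 2 * M * K + 2 * K * N + 2 * M * N; // fp16: 2 bytes per element
    std::printf("%.6f TFlops\n", flop / 1.e9 / ave_time_ms); // ~0.916
    std::printf("%.3f GB/s\n", bytes / 1.e6 / ave_time_ms);  // ~917.3
    return 0;
}
```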
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final // Default GEMV problem size
{
ck::index_t M = 1;
ck::index_t N = 1104;
ck::index_t K = 4608;
ck::index_t stride_A = K;
ck::index_t stride_B = K; // use K if BLayout is ColumnMajor
ck::index_t stride_C = N;
ck::index_t k_batch = 1;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Row; // Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 4
#define K0 3
#define N1 2
#define B 64 // block-size:64
// clang-format off
using DeviceGemvInstance = ck::tensor_operation::device::deviceTsmmDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, 1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1, 1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
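// With the macros above (B = 64, N1 = 2, K0 = 3, K1 = 4), the selected tile configuration is
// BlockSize = 64, MPerBlock = 1, NPerBlock = B * N1 = 128, K0PerBlock = 3, K1 = 4,
// MPerThread = 1, NPerThread = 2, KPerThread = 1, which GetTypeString() reports as
// deviceTsmmDl<64, 1, 128, 3, 4, 1, 2, 1> in the Perf output.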
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemv_splitk_example.inc"
int main(int argc, char* argv[]) { return !run_gemv_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size; // //
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemv = DeviceGemvInstance{};
auto invoker = gemv.MakeInvoker();
auto argument = gemv.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
k_batch); // //
// //
if(!gemv.IsSupportedArgument(argument))
{
std::cerr << gemv.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
c_m_n_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
if(config.do_verification)
{
auto ref_gemv = ReferenceGemmInstance{};
auto ref_invoker = ref_gemv.MakeInvoker();
auto ref_argument = ref_gemv.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
}
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemv.GetTypeString() << std::endl;
#ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
bool run_gemv_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: splitk\n");
printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
return run_gemv(problem_size, config);
}
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_tall_and_skinny_gemm_splitk)
add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
add_dependencies(example_tall_and_skinny_gemm_splitk
example_tall_and_skinny_gemm_splitk_fp16)
set(target 1)
endif()
endforeach()
# Instructions for ```example_tall_and_skinny_gemm_splitk```
## Run ```example_tall_and_skinny_gemm_splitk```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
#arg4: number of split-K batches
bin/example_tall_and_skinny_gemm_splitk_fp16 1 2 1 231
```
Result (MI250 @ 800 MHz, 181.05 TFlops peak FP16)
```
a_m_k: dim 2, lengths {16, 1024}, strides {1024, 1}
b_k_n: dim 2, lengths {1024, 16}, strides {16, 1}
c_m_n: dim 2, lengths {16, 16}, strides {16, 1}
Perf: 0.0065438 ms, 0.0801198 TFlops, 10.0932 GB/s, deviceTsmmDl<64, 16, 128, 4, 2, 16, 2, 1>
```
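The instance name in the Perf line is produced by `deviceTsmmDl::GetTypeString()`; read against the macros in `tall_and_skinny_gemm_splitk_fp16.cpp`, its fields decode as follows (a reading aid, not additional output):
```cpp
// deviceTsmmDl<64, 16, 128, 4, 2, 16, 2, 1>
//   BlockSize  = B      = 64
//   MPerBlock  = M1     = 16
//   NPerBlock  = B * N1 = 128
//   K0PerBlock = K0     = 4
//   K1         = K1     = 2
//   MPerThread = M1     = 16
//   NPerThread = N1     = 2
//   KPerThread          = 1
```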
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final // Default GEMV problem size
{
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 1024;
ck::index_t stride_A = K;
ck::index_t stride_B = K; // use K if BLayout is ColumnMajor
ck::index_t stride_C = N;
ck::index_t k_batch = 1;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size; // //
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto tsmm = DeviceTSMMInstance{};
auto invoker = tsmm.MakeInvoker();
auto argument = tsmm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
k_batch); // //
// //
if(!tsmm.IsSupportedArgument(argument))
{
std::cerr << tsmm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
c_m_n_device_buf.SetZero();
if(config.do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
auto ref_tsmm = ReferenceGemmInstance{};
auto ref_invoker = ref_tsmm.MakeInvoker();
auto ref_argument = ref_tsmm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
}
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< tsmm.GetTypeString() << std::endl;
#ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: splitk\n");
printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
return run_tall_and_skinny_gemm(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Row; // Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 2
#define K0 4
#define N1 2
#define B 64 // block-size:64
#define M1 16
// clang-format off
using DeviceTSMMInstance = ck::tensor_operation::device::deviceTsmmDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_tall_and_skinny_gemm_splitk_example.inc"
int main(int argc, char* argv[]) { return !run_tall_and_skinny_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tall_and_skinny_gemm.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_K0_M_K1,
typename BThreadDesc_K0_N_K1,
index_t MPerThread,
index_t NPerBlock,
index_t K0PerLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
using CIndex = MultiIndex<4>;
static constexpr auto K0 = ABlockDesc_K0_M_K1{}.GetLength(I0);
static constexpr auto M = ABlockDesc_K0_M_K1{}.GetLength(I1);
static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);
static constexpr auto NPerThread = BThreadDesc_K0_N_K1{}.GetLength(I1);
static constexpr auto M0 = M / MPerThread;
static constexpr auto M1 = MPerThread;
static constexpr auto N = NPerBlock;
static constexpr auto N0 = N / NPerThread;
static constexpr auto N1 = NPerThread;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerLoop>{}, Number<MPerThread>{}, Number<K1>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerLoop>{}, Number<NPerThread>{}, Number<K1>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<I1>{}, Number<M1>{}, Number<I1>{}, Number<N1>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)}
{
static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BThreadDesc_K0_N_K1{}.GetLength(I0) &&
ABlockDesc_K0_M_K1{}.GetLength(I2) == BThreadDesc_K0_N_K1{}.GetLength(I2),
"wrong! E dimension not consistent\n");
static_assert(K0 % K0PerLoop == 0, "");
static_assert(M % MPerThread == 0 && N % NPerThread == 0,
"wrong! Cannot evenly divide work among threads\n");
static_assert(BlockSize == M0 * N0, "wrong! wrong blocksize\n");
}
__device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<I1, M1, I1, N1>{};
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
constexpr auto c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, I1, N0, I1))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto c_m0_m1_n0_n1_thread_cluster_idx =
c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
return c_m0_m1_n0_n1_thread_cluster_idx;
}
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BThreadBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value,
"wrong! inconsistent type");
constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatB,
FloatC,
decltype(a_thread_mtx_),
decltype(b_thread_mtx_),
decltype(c_thread_mtx_)>{};
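// Walk K0 in chunks of K0PerLoop: copy this thread's slice of the A block from LDS into
// registers, then let the threadwise GEMM accumulate the partial products into the C thread
// buffer (B is already resident in registers via b_thread_buf).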
static_for<0, K0, K0PerLoop>{}([&](auto k0_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(k0_begin, I0, I0),
a_block_buf,
a_thread_mtx_,
make_tuple(I0, I0, I0),
a_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(k0_begin, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
});
}
private:
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc_K0_M_K1,
decltype(a_thread_mtx_),
Sequence<K0PerLoop, MPerThread, K1>,
Sequence<0, 1, 2>,
2,
K1,
K1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceTsmm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
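// Usage sketch (hypothetical, not part of this commit): a client such as ckprofiler can drive
// any concrete instance through this type-erased interface roughly as
//
//   auto arg     = op.MakeArgumentPointer(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
//                                         PassThrough{}, PassThrough{}, PassThrough{}, k_batch);
//   auto invoker = op.MakeInvokerPointer();
//   if(op.IsSupportedArgument(arg.get()))
//   {
//       const float ave_time_ms = invoker->Run(arg.get(), StreamConfig{nullptr, true});
//   }
//
// where op is a reference to a DeviceTsmm<...> instance and PassThrough the element-wise op.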
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct deviceTsmmDl : public DeviceTsmm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// GridwiseTsmm
using GridwiseTsmm =
GridwiseTsmmDl_km_kn_mn<BlockSize,
ADataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
GemmSpec,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
MPerThread,
NPerThread,
KPerThread,
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using DefaultBlock2CTileMap = typename GridwiseTsmm::DefaultBlock2CTileMap;
using Argument = typename GridwiseTsmm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
// const auto b2c_map = DefaultBlock2CTileMap{};
const auto K0 = karg.K0;
const bool has_main_k_block_loop = GridwiseTsmm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseTsmm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
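// Split-K path (k_batch > 1): each k-batch computes a partial GEMM over its slice of K and
// accumulates into C with AtomicAdd (see the kernel selection below), so C must be zeroed
// up front.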
if(karg.k_batch > 1)
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
true,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
true,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
true,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
true,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
// //
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
return GridwiseTsmm::CheckValidity(arg);
}
else
{
return false;
}
}
// //
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
index_t KBatch) // //
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
// GridwiseTsmm::CalculateMPadded(M),
// GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch}; // //
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ck::index_t KBatch = 1) override // //
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
// GridwiseTsmm::CalculateMPadded(M),
// GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch); // //
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "deviceTsmmDl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerThread << ", "
<< NPerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseTsmm,
typename FloatAB,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop,
typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_tsmm_dl_v1r3(
typename GridwiseTsmm::Argument karg) // argument struct passed by value: in __global__
// functions this reduces load overhead
{
GridwiseTsmm::template Run<HasMainKBlockLoop,
HasDoubleTailKBlockLoop,
GridwiseTsmm,
CGlobalMemoryDataOperation>(karg);
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
typename ALayout,
typename BLayout,
typename CLayout,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseTsmmDl_km_kn_mn
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
// Argument
struct Argument : public tensor_operation::device::BaseArgument //
{
Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
// index_t MPadded_,
// index_t NPadded_,
// index_t KPadded_,
index_t K0_,
index_t k_batch_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
// MPadded(MPadded_),
// NPadded(NPadded_),
// KPadded(KPadded_),
K0(K0_),
k_batch(k_batch_)
{
}
// private:
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M, N, K;
index_t StrideA, StrideB, StrideC;
//:
// index_t MPadded;
// index_t NPadded;
// index_t KPadded;
index_t K0;
index_t k_batch;
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
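// Factor of 2: the A block is double-buffered in LDS (see the even/odd buffers in Run()).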
return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr index_t
CalculateGridSize(index_t M, index_t N, index_t k_batch) //
{
const index_t grid_size = math::integer_divide_ceil(N, NPerBlock) *
math::integer_divide_ceil(M, MPerBlock) * k_batch;
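// One workgroup per MPerBlock x NPerBlock tile of C, per k-batch. E.g. for the GEMV example
// (M = 1, N = 1104, MPerBlock = 1, NPerBlock = 128, k_batch = 231):
// grid_size = ceil(1104 / 128) * ceil(1 / 1) * 231 = 9 * 1 * 231 = 2079 workgroups.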
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
}
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return math::integer_least_multiple(N, NPerBlock);
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
{
// pad K so that K_Batch * K0 * K1 covers it, with K0 (per k-batch) a multiple of K0PerBlock
auto K_t = K_Batch * K0PerBlock * K1;
return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
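// Worked example for CalculateK0/CalculateKPadded above, using the GEMV example's settings
// (K0PerBlock = 3, K1 = 4) with K = 4608 and K_Batch = 231:
//   K_t = 231 * 3 * 4 = 2772, K0 = ceil(4608 / 2772) * 3 = 2 * 3 = 6 (per k-batch),
//   KPadded = 231 * 6 * 4 = 5544 >= K.
// With K0 = 6, CalculateHasMainKBlockLoop gives (6 + 3) / (2 * 3) = 1, which is not > 1 (no
// main K loop), and CalculateHasDoubleTailKBlockLoop gives (6 / 3) % 2 == 0 (double tail),
// so the "no main loop, double tail" kernel variant is launched by the device op's Invoker.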
static constexpr auto K1Number = Number<K1>{};
// M, K -> KBatch, K0, M, K1: M -> MPad, K->KBatch, K0, K1
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(
index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(
make_tuple(KBatch, K0, K1Number)), // unmerge is split 1D to 3D
make_right_pad_transform(M, MPad - M)), //
make_tuple(Sequence<1>{}, Sequence<0>{}), // mapped to input M & K; sequence 0 is M;
// 1 is K; make unmerge is working on K;
make_tuple(Sequence<0, 1, 3>{}, // input is M,K; output we want is Kbatch, K0 and K1
// -> 0, 1, 3; output is transformed from 2D to 4D
Sequence<2>{})); // 2->M
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(
index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
{
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
using AGridDesc_Kbatch_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1, 1));
using BGridDesc_Kbatch_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
const auto MPadded = CalculateMPadded(karg.M);
const auto NPadded = CalculateNPadded(karg.N);
const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto KBatch_b = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto N_ = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
return (M_ % MPerBlock == 0 && N_ % NPerBlock == 0 && K0_ % K0PerBlock == 0 &&
M_ == c_grid_desc_m_n.GetLength(I0) && N_ == c_grid_desc_m_n.GetLength(I1) &&
a_grid_desc_kbatch_k0_m_k1.GetLength(I3) ==
b_grid_desc_kbatch_k0_n_k1.GetLength(I3) &&
karg.k_batch >= 1 && KBatch_a == karg.k_batch && KBatch_b == karg.k_batch);
}
// KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
__host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1)
{
const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto K0 = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_kbatch_k0_m0_m1_k1 = transform_tensor_descriptor(
a_grid_desc_kbatch_k0_m_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), // IP
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); // OP
return a_grid_desc_kbatch_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1)
{
const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0 = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
const auto N = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_kbatch_k0_n0_n1_k1 = transform_tensor_descriptor(
b_grid_desc_kbatch_k0_n_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_kbatch_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 = Number<MPerThread>{};
constexpr auto N11 = Number<NPerThread>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
{
//: 3d ksplit for C
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
}
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>; //
using AGridDesc_K0_M0_M1_K1 =
decltype(MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(AGridDesc_Kbatch_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(BGridDesc_Kbatch_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); //
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap()); //
template <bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop,
typename GridwiseTsmm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static void Run(const Argument& karg)
{
constexpr index_t shared_block_size =
GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
const Block2CTileMap& block_2_ctile_map = Block2CTileMap{};
const auto MPadded = CalculateMPadded(karg.M);
const auto NPadded = CalculateNPadded(karg.N);
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0); //
const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0); //
const auto c_grid_desc_m_n =
GridwiseTsmm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
const auto b_grid_desc_kbatch_k0_n0_n1_k1 =
GridwiseTsmm::MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1); //
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseTsmm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_kbatch_k0_n0_n1_k1.GetElementSpaceSize());
ignore = b_global_buf;
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
get_block_1d_id(), karg.N, karg.k_batch);
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_copy_kbatch_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>, // 5-dimensional block slice lengths; the
// KBatch dimension has slice length 1 per block
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, // 0, 1, 2, 3, 4
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_kbatch_k0_m0_m1_k1)>, // Global tensor desc
decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1), // block tensor desc
ABlockTransferSrcAccessOrder, // 5-dim
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_kbatch_k0_m0_m1_k1, // for src desc
make_multi_index(kbatch_id, 0, im0, 0, 0), //: calculate start index of K
a_block_desc_copy_kbatch_k0_m0_m1_k1, // for dst desc
make_multi_index(0, 0, 0, 0, 0));
static constexpr auto b_thread_desc_copy_kbatch_k0_n0_n1_k1 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<K0PerBlock>{},
I1,
Number<NPerThread>{},
Number<K1>{})); // this descriptor is used only for the copy
static constexpr auto b_thread_desc_copy_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_kbatch_k0_n0_n1_k1)>,
decltype(b_thread_desc_copy_kbatch_k0_n0_n1_k1), //
Sequence<1, K0PerBlock, 1, NPerThread, K1.value>,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
1,
false,
true>(b_grid_desc_kbatch_k0_n0_n1_k1,
make_multi_index(kbatch_id, 0, in0, get_thread_local_1d_id() * NPerThread, 0));
static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize(),
"wrong!");
const auto blockwise_tsmm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_thread_desc),
MPerThread,
NPerBlock,
KPerThread>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_tsmm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A (B stays in thread registers): be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_k0_n_k1_thread_desc.GetElementSpaceSize());
auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_k0_n_k1_thread_desc.GetElementSpaceSize());
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
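// split the single LDS allocation into two ping-pong buffers (even/odd), each
// a_block_aligned_space_size elements, so global loads overlap with compute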
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1,
a_global_buf); // a_global_buf -> reg_tmp_buf
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1,
a_block_even_buf); // reg_tmp_buf->a_block_even_buf
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
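// each do-while iteration consumes two K0PerBlock slices: compute on the even
// buffers while the odd buffers are being filled, then the roles are swapped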
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // two iterations remain
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
}
else // one iteration remains
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_tsmm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
#include "ck/utility/common_header.hpp"
namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_K0_M_K1,
typename BThreadDesc_K0_N_K1,
typename CThreadDesc_M_N,
typename enable_if<AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
CThreadDesc_M_N::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
CThreadDesc_M_N::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value,
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto K0 = AThreadDesc_K0_M_K1{}.GetLength(I0);
constexpr auto M = AThreadDesc_K0_M_K1{}.GetLength(I1);
constexpr auto K1 = AThreadDesc_K0_M_K1{}.GetLength(I2);
constexpr auto N = BThreadDesc_K0_N_K1{}.GetLength(I1);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
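// accumulate c[m, n] += a[k0, m, k1] * b[k0, n, k1] over K0 and K1 via inner_product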
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
static_for<0, K0, 1>{}([&](auto k0) {
static_for<0, K1, 1>{}([&](auto k1) {
constexpr index_t a_offset = AThreadDesc_K0_M_K1{}.CalculateOffset(
a_origin_idx + make_tuple(k0, m, k1));
constexpr index_t b_offset = BThreadDesc_K0_N_K1{}.CalculateOffset(
b_origin_idx + make_tuple(k0, n, k1));
constexpr index_t c_offset = CThreadDesc_M_N{}.CalculateOffset(
c_origin_idx + make_tuple(0, m, 0, n));
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
});
});
});
});
}
};
} // namespace ck
#endif
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
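# build the split-K GEMV instance library only once, and only when at least one
# supported gfx target is listed in GPU_TARGETS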
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
set(GEMV_SPLITK_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMV_SPLITK_INSTANCES device_gemv_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMV_SPLITK_INSTANCES device_gemv_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
endif()
add_instance_library(device_gemv_splitk_instance ${GEMV_SPLITK_INSTANCES})
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
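// The tuple below sweeps DL tall-and-skinny GEMM (split-K GEMV) tuning parameters:
// K0PerBlock in {1..8}, K1 in {2, 4, 8}, and NPerThread in {2, 4, 8} (NPerBlock 128/256/512)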
using device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//N1=2
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//N1=4
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//N1=8
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
>;
void add_device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceTsmm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
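// Same parameter sweep as the row-major-B (mk_kn) file; here B is column-major (mk_nk),
// so the B thread transfer vectorizes along K1 (SrcVectorDim 4, SrcScalarPerVector == K1)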
using device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 4, K1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//N1=2
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//N1=4
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//N1=8
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
>;
void add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceTsmm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
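For context, here is a minimal sketch (an editor illustration, not code from this commit) of how the instance factory defined above could be consumed. It assumes `DeviceTsmm` follows composable_kernel's usual `BaseOperator` convention of exposing `GetTypeString()`, reuses the include path that appears in this commit's instance files, and repeats the factory declaration so the snippet is self-contained; the helper name `print_gemv_splitk_instances` is hypothetical.

```cpp
// Editor's sketch (not part of this commit): enumerate the GEMV split-K device
// instances registered by the factory above and print their type strings.
// Assumption: DeviceTsmm follows ck's BaseOperator convention (GetTypeString()).
#include <iostream>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
// Include path taken from this commit's instance files; assumed to declare
// (or pull in) the DeviceTsmm interface.
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"

namespace {
using F16         = ck::half_t;
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceOp = ck::tensor_operation::device::
    DeviceTsmm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
} // namespace

namespace ck::tensor_operation::device::instance {
// Declaration normally provided by the instance-library header added in this
// commit; repeated here only to keep the sketch self-contained.
void add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceOp>>& instances);
} // namespace ck::tensor_operation::device::instance

// Hypothetical helper: list every registered deviceTsmmDl configuration.
void print_gemv_splitk_instances()
{
    std::vector<std::unique_ptr<DeviceOp>> instances;
    ck::tensor_operation::device::instance::
        add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(instances);

    for(const auto& inst : instances)
        std::cout << inst->GetTypeString() << std::endl;
}
```

The same pattern applies to the tall-and-skinny factory defined further below, with `Row, Row, Row` layouts in place of `Row, Col, Row`.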
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
set(TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
endif()
add_instance_library(device_tall_and_skinny_gemm_splitk_instance ${TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES})
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//M1 is fixed at 16 for these tall-and-skinny instances
//N1=2
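// Worked reading of the next entry, decoded against the legend above (an
// editor's illustration, not an authoritative parameter reference):
// BlockSize B = 64, MPerBlock = M1 = 16, NPerBlock = B*N1 = 128 (so N1 = 2),
// K0PerBlock = 1, K1 = 2, M1PerThread = 16, N1PerThread = 2, KPerThread = 1.
// The S<...> sequences give per-thread slice lengths, thread-cluster lengths
// and access orders for the A/B/C transfers; the A-side slice/vector lengths
// carry K1, while the B/C scalar-per-vector values equal N1.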
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//N1=4
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//N1=8
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
>;
void add_device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceTsmm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances, device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck