Commit 49facb91 authored by Harisankar Sadasivan

Add files for GEMV and tall-and-skinny GEMM examples, and corresponding entries to ckprofiler

parent 98fd41f5
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemv_splitk)
add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
add_dependencies(example_gemv_splitk
example_gemv_splitk_fp16)
set(target 1)
endif()
endforeach()
# Instructions for `example_gemv_splitk`
## Run `example_gemv_splitk`
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
#arg4: number of splitk batches
bin/example_gemv_splitk_fp16 1 2 1 231
```
Result (MI250 @ 800 MHz, 181.05 TFlops peak FP16)
```
a_m_k: dim 2, lengths {1, 4608}, strides {4608, 1}
b_k_n: dim 2, lengths {4608, 1104}, strides {1104, 1}
c_m_n: dim 2, lengths {1, 1104}, strides {1104, 1}
Perf: 0.0111038 ms, 0.916305 TFlops, 917.334 GB/s, deviceTsmmDl<64, 1, 128, 3, 4, 1, 2, 1>
```
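For reference, the Perf numbers follow the formulas in `run_gemv_splitk_example.inc` (`flop = 2*M*N*K`, `bytes = sizeof(fp16) * (M*K + K*N + M*N)`); for the run above:
```
flop   = 2 * 1 * 1104 * 4608               ≈ 1.017e7
bytes  = 2 * (1*4608 + 4608*1104 + 1*1104) ≈ 1.019e7
TFlops = flop  / 1e9 / 0.0111038 ms        ≈ 0.916
GB/s   = bytes / 1e6 / 0.0111038 ms        ≈ 917
```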
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final // Default GEMV problem size
{
ck::index_t M = 1;
ck::index_t N = 1104;
ck::index_t K = 4608;
ck::index_t stride_A = K;
ck::index_t stride_B = N; // K;
ck::index_t stride_C = N;
ck::index_t k_batch = 1;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4: number of split-K batches" << std::endl
<< "arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Row; // Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 4
#define K0 3
#define N1 2
#define B 64 // block-size:64
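// How these macros map onto the deviceTsmmDl template arguments below (see the commented
// parameter table): BlockSize = B = 64, MPerBlock = 1, NPerBlock = B*N1 = 128,
// K0PerBlock = K0 = 3, K1 = 4, MPerThread = 1, NPerThread = N1 = 2, KPerThread = 1 --
// the same tuple that GetTypeString() reports as deviceTsmmDl<64, 1, 128, 3, 4, 1, 2, 1>.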
// clang-format off
using DeviceGemvInstance = ck::tensor_operation::device::deviceTsmmDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, 1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1, 1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemv_splitk_example.inc"
int main(int argc, char* argv[]) { return !run_gemv_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size; // //
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemv = DeviceGemvInstance{};
auto invoker = gemv.MakeInvoker();
auto argument = gemv.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
k_batch); // //
// //
if(!gemv.IsSupportedArgument(argument))
{
std::cerr << gemv.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
c_m_n_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
if(config.do_verification)
{
auto ref_gemv = ReferenceGemmInstance{};
auto ref_invoker = ref_gemv.MakeInvoker();
auto ref_argument = ref_gemv.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
}
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemv.GetTypeString() << std::endl;
#ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
bool run_gemv_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: number of split-K batches\n");
printf("arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
return run_gemv(problem_size, config);
}
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_tall_and_skinny_gemm_splitk)
add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
add_dependencies(example_tall_and_skinny_gemm_splitk
example_tall_and_skinny_gemm_splitk_fp16)
set(target 1)
endif()
endforeach()
# Instructions for `example_tall_and_skinny_gemm_splitk`
## Run `example_tall_and_skinny_gemm_splitk`
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
#arg4: number of splitk batches
bin/example_tall_and_skinny_gemm_splitk_fp16 1 2 1 231
```
Result (MI250 @ 800 MHz, 181.05 TFlops peak FP16)
```
a_m_k: dim 2, lengths {16, 1024}, strides {1024, 1}
b_k_n: dim 2, lengths {1024, 16}, strides {16, 1}
c_m_n: dim 2, lengths {16, 16}, strides {16, 1}
Perf: 0.0065438 ms, 0.0801198 TFlops, 10.0932 GB/s, deviceTsmmDl<64, 16, 128, 4, 2, 16, 2, 1>
```
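The split-K count (arg4) controls how `K` is decomposed across the grid: `GridwiseTsmmDl_km_kn_mn::CalculateK0`/`CalculateKPadded` (further below) round `K` up to a multiple of `k_batch * K0PerBlock * K1`. As an illustration with this example's tile configuration (`K0PerBlock = 4`, `K1 = 2`), the default `K = 1024`, and an illustrative `k_batch = 8` (the run above used 231):
```
K_t     = k_batch * K0PerBlock * K1  = 8 * 4 * 2  = 64
K0      = ceil(K / K_t) * K0PerBlock = 16 * 4     = 64
KPadded = k_batch * K0 * K1          = 8 * 64 * 2 = 1024   (K already divides evenly)
```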
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final // Default tall-and-skinny GEMM problem size
{
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 1024;
ck::index_t stride_A = K;
ck::index_t stride_B = N; // K;
ck::index_t stride_C = N;
ck::index_t k_batch = 1;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4: number of split-K batches" << std::endl
<< "arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size; // //
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto tsmm = DeviceTSMMInstance{};
auto invoker = tsmm.MakeInvoker();
auto argument = tsmm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
k_batch); // //
// //
if(!tsmm.IsSupportedArgument(argument))
{
std::cerr << tsmm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
c_m_n_device_buf.SetZero();
if(config.do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
auto ref_tsmm = ReferenceGemmInstance{};
auto ref_invoker = ref_tsmm.MakeInvoker();
auto ref_argument = ref_tsmm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
}
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< tsmm.GetTypeString() << std::endl;
#ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: number of split-K batches\n");
printf("arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
return run_tall_and_skinny_gemm(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Row; // Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 2
#define K0 4
#define N1 2
#define B 64 // block-size:64
#define M1 16
// clang-format off
using DeviceTSMMInstance = ck::tensor_operation::device::deviceTsmmDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_tall_and_skinny_gemm_splitk_example.inc"
int main(int argc, char* argv[]) { return !run_tall_and_skinny_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tall_and_skinny_gemm.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_K0_M_K1,
typename BThreadDesc_K0_N_K1,
index_t MPerThread,
index_t NPerBlock,
index_t K0PerLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
using CIndex = MultiIndex<4>;
static constexpr auto K0 = ABlockDesc_K0_M_K1{}.GetLength(I0);
static constexpr auto M = ABlockDesc_K0_M_K1{}.GetLength(I1);
static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);
static constexpr auto NPerThread = BThreadDesc_K0_N_K1{}.GetLength(I1);
static constexpr auto M0 = M / MPerThread;
static constexpr auto M1 = MPerThread;
static constexpr auto N = NPerBlock;
static constexpr auto N0 = N / NPerThread;
static constexpr auto N1 = NPerThread;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerLoop>{}, Number<MPerThread>{}, Number<K1>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerLoop>{}, Number<NPerThread>{}, Number<K1>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<I1>{}, Number<M1>{}, Number<I1>{}, Number<N1>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)}
{
static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BThreadDesc_K0_N_K1{}.GetLength(I0) &&
ABlockDesc_K0_M_K1{}.GetLength(I2) == BThreadDesc_K0_N_K1{}.GetLength(I2),
"wrong! K dimension not consistent\n");
static_assert(K0 % K0PerLoop == 0, "");
static_assert(M % MPerThread == 0 && N % NPerThread == 0,
"wrong! Cannot evenly divide work among threads\n");
static_assert(BlockSize == M0 * N0, "wrong! wrong blocksize\n");
}
__device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<I1, M1, I1, N1>{};
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
constexpr auto c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, I1, N0, I1))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto c_m0_m1_n0_n1_thread_cluster_idx =
c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
return c_m0_m1_n0_n1_thread_cluster_idx;
}
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BThreadBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value,
"wrong! inconsistent type");
constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatB,
FloatC,
decltype(a_thread_mtx_),
decltype(b_thread_mtx_),
decltype(c_thread_mtx_)>{};
static_for<0, K0, K0PerLoop>{}([&](auto k0_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(k0_begin, I0, I0),
a_block_buf,
a_thread_mtx_,
make_tuple(I0, I0, I0),
a_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(k0_begin, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
});
}
private:
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc_K0_M_K1,
decltype(a_thread_mtx_),
Sequence<K0PerLoop, MPerThread, K1>,
Sequence<0, 1, 2>,
2,
K1,
K1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceTsmm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
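// Illustrative usage sketch (not part of this commit): a ckprofiler entry or other client can
// drive any DeviceTsmm implementation through the type-erased interface above, mirroring what
// the examples do with the concrete MakeArgument/MakeInvoker calls. The pointers, sizes, and
// layout/data-type arguments below are placeholders.
//
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//   std::unique_ptr<DeviceTsmm<Row, Row, Row, ck::half_t, ck::half_t, ck::half_t,
//                              PassThrough, PassThrough, PassThrough>>
//       op = /* some concrete instance, e.g. deviceTsmmDl<...> */;
//   auto arg_ptr     = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
//                                              PassThrough{}, PassThrough{}, PassThrough{}, k_batch);
//   auto invoker_ptr = op->MakeInvokerPointer();
//   if(op->IsSupportedArgument(arg_ptr.get()))
//   {
//       float ms = invoker_ptr->Run(arg_ptr.get(), StreamConfig{nullptr, /*time_kernel=*/true});
//   }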
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct deviceTsmmDl : public DeviceTsmm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// GridwiseTsmm
using GridwiseTsmm =
GridwiseTsmmDl_km_kn_mn<BlockSize,
ADataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
GemmSpec,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
MPerThread,
NPerThread,
KPerThread,
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using DefaultBlock2CTileMap = typename GridwiseTsmm::DefaultBlock2CTileMap;
using Argument = typename GridwiseTsmm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
// const auto b2c_map = DefaultBlock2CTileMap{};
const auto K0 = karg.K0;
const bool has_main_k_block_loop = GridwiseTsmm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseTsmm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
if(karg.k_batch > 1)
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
true,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
true,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
true,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
true,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
// //
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
return GridwiseTsmm::CheckValidity(arg);
}
else
{
return false;
}
}
// //
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
index_t KBatch) // //
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
// GridwiseTsmm::CalculateMPadded(M),
// GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch}; // //
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ck::index_t KBatch = 1) override // //
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
// GridwiseTsmm::CalculateMPadded(M),
// GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch); // //
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "deviceTsmmDl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerThread << ", "
<< NPerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseTsmm,
typename FloatAB,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop,
typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_tsmm_dl_v1r3(
typename GridwiseTsmm::Argument karg) // in __global__ functions, passing arguments as a
// struct reduces argument-load overhead
{
GridwiseTsmm::template Run<HasMainKBlockLoop,
HasDoubleTailKBlockLoop,
GridwiseTsmm,
CGlobalMemoryDataOperation>(karg);
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
typename ALayout,
typename BLayout,
typename CLayout,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseTsmmDl_km_kn_mn
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
// Argument
struct Argument : public tensor_operation::device::BaseArgument //
{
Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
// index_t MPadded_,
// index_t NPadded_,
// index_t KPadded_,
index_t K0_,
index_t k_batch_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
// MPadded(MPadded_),
// NPadded(NPadded_),
// KPadded(KPadded_),
K0(K0_),
k_batch(k_batch_)
{
}
// private:
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M, N, K;
index_t StrideA, StrideB, StrideC;
//:
// index_t MPadded;
// index_t NPadded;
// index_t KPadded;
index_t K0;
index_t k_batch;
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr index_t
CalculateGridSize(index_t M, index_t N, index_t k_batch) //
{
const index_t grid_size = math::integer_divide_ceil(N, NPerBlock) *
math::integer_divide_ceil(M, MPerBlock) * k_batch;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
}
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return math::integer_least_multiple(N, NPerBlock);
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
{
// k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1;
return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
static constexpr auto K1Number = Number<K1>{};
// M, K -> KBatch, K0, M, K1: M -> MPad, K->KBatch, K0, K1
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(
index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(
make_tuple(KBatch, K0, K1Number)), // unmerge is split 1D to 3D
make_right_pad_transform(M, MPad - M)), //
make_tuple(Sequence<1>{}, Sequence<0>{}), // mapped to input M & K; sequence 0 is M;
// 1 is K; make unmerge is working on K;
make_tuple(Sequence<0, 1, 3>{}, // input is M,K; output we want is Kbatch, K0 and K1
// -> 0, 1, 3; output is transformed from 2D to 4D
Sequence<2>{})); // 2->M
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(
index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
{
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
using AGridDesc_Kbatch_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1, 1));
using BGridDesc_Kbatch_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
const auto MPadded = CalculateMPadded(karg.M);
const auto NPadded = CalculateNPadded(karg.N);
const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto KBatch_b = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto N_ = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
return (M_ % MPerBlock == 0 && N_ % NPerBlock == 0 && K0_ % K0PerBlock == 0 &&
M_ == c_grid_desc_m_n.GetLength(I0) && N_ == c_grid_desc_m_n.GetLength(I1) &&
a_grid_desc_kbatch_k0_m_k1.GetLength(I3) ==
b_grid_desc_kbatch_k0_n_k1.GetLength(I3) &&
karg.k_batch >= 1 && KBatch_a == karg.k_batch && KBatch_b == karg.k_batch);
}
// KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
__host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1)
{
const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto K0 = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_kbatch_k0_m0_m1_k1 = transform_tensor_descriptor(
a_grid_desc_kbatch_k0_m_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), // IP
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); // OP
return a_grid_desc_kbatch_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1)
{
const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0 = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
const auto N = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_kbatch_k0_n0_n1_k1 = transform_tensor_descriptor(
b_grid_desc_kbatch_k0_n_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_kbatch_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 = Number<MPerThread>{};
constexpr auto N11 = Number<NPerThread>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
{
//: 3d ksplit for C
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
}
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>; //
using AGridDesc_K0_M0_M1_K1 =
decltype(MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(AGridDesc_Kbatch_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(BGridDesc_Kbatch_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); //
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap()); //
template <bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop,
typename GridwiseTsmm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static void Run(const Argument& karg)
{
constexpr index_t shared_block_size =
GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
const Block2CTileMap& block_2_ctile_map = Block2CTileMap{};
const auto MPadded = CalculateMPadded(karg.M);
const auto NPadded = CalculateNPadded(karg.N);
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0); //
const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0); //
const auto c_grid_desc_m_n =
GridwiseTsmm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
const auto b_grid_desc_kbatch_k0_n0_n1_k1 =
GridwiseTsmm::MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1); //
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseTsmm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_kbatch_k0_n0_n1_k1.GetElementSpaceSize());
ignore = b_global_buf;
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
get_block_1d_id(), karg.N, karg.k_batch);
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_copy_kbatch_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>, // 5-D block slice lengths; the KBatch
// extent copied per block is 1
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, // 0, 1, 2, 3, 4
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_kbatch_k0_m0_m1_k1)>, // Global tensor desc
decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1), // block tensor desc
ABlockTransferSrcAccessOrder, // 5-dim
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_kbatch_k0_m0_m1_k1, // for src desc
make_multi_index(kbatch_id, 0, im0, 0, 0), //: calculate start index of K
a_block_desc_copy_kbatch_k0_m0_m1_k1, // for dst desc
make_multi_index(0, 0, 0, 0, 0));
static constexpr auto b_thread_desc_copy_kbatch_k0_n0_n1_k1 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<K0PerBlock>{},
I1,
Number<NPerThread>{},
Number<K1>{})); // this descriptor is used only for copy
static constexpr auto b_thread_desc_copy_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_kbatch_k0_n0_n1_k1)>,
decltype(b_thread_desc_copy_kbatch_k0_n0_n1_k1), //
Sequence<1, K0PerBlock, 1, NPerThread, K1.value>,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
1,
false,
true>(b_grid_desc_kbatch_k0_n0_n1_k1,
make_multi_index(kbatch_id, 0, in0, get_thread_local_1d_id() * NPerThread, 0));
static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize(),
"wrong!");
const auto blockwise_tsmm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_thread_desc),
MPerThread,
NPerBlock,
KPerThread>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_tsmm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_k0_n_k1_thread_desc.GetElementSpaceSize());
auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_k0_n_k1_thread_desc.GetElementSpaceSize());
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1,
a_global_buf); // a_global_buf -> reg_tmp_buf
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1,
a_block_even_buf); // reg_tmp_buf->a_block_even_buf
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
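// Pipeline sketch (one loop trip consumes 2 * K0PerBlock along K0):
//   even half: prefetch the next A/B chunk from global memory, run GEMM on the even
//              LDS/VGPR buffers, then write the prefetched A chunk into the odd LDS buffer;
//   odd half:  same steps with the even/odd buffer roles swapped.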
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
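// Tail handling: the main loop always leaves either one or two K0PerBlock chunks unconsumed
// (depending on the parity of K0 / K0PerBlock). With two chunks left, the first branch loads
// the final chunk and runs GEMM on both; with one chunk left, the data is already resident in
// the even buffers and only the final GEMM is issued.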
if constexpr(HasDoubleTailKBlockLoop) // two iterations left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
}
else // one iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_tsmm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
#include "ck/utility/common_header.hpp"
namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
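// Reference semantics of Run() below, as a minimal sketch assuming scalar (non-vectorized)
// elements and zero origin offsets:
//
//   for(m = 0; m < M; ++m)
//       for(n = 0; n < N; ++n)
//           for(k0 = 0; k0 < K0; ++k0)
//               for(k1 = 0; k1 < K1; ++k1)
//                   c[m][n] += a[k0][m][k1] * b[k0][n][k1];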
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_K0_M_K1,
typename BThreadDesc_K0_N_K1,
typename CThreadDesc_M_N,
typename enable_if<AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
CThreadDesc_M_N::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
CThreadDesc_M_N::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value,
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto K0 = AThreadDesc_K0_M_K1{}.GetLength(I0);
constexpr auto M = AThreadDesc_K0_M_K1{}.GetLength(I1);
constexpr auto K1 = AThreadDesc_K0_M_K1{}.GetLength(I2);
constexpr auto N = BThreadDesc_K0_N_K1{}.GetLength(I1);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
static_for<0, K0, 1>{}([&](auto k0) {
static_for<0, K1, 1>{}([&](auto k1) {
constexpr index_t a_offset = AThreadDesc_K0_M_K1{}.CalculateOffset(
a_origin_idx + make_tuple(k0, m, k1));
constexpr index_t b_offset = BThreadDesc_K0_N_K1{}.CalculateOffset(
b_origin_idx + make_tuple(k0, n, k1));
constexpr index_t c_offset = CThreadDesc_M_N{}.CalculateOffset(
c_origin_idx + make_tuple(0, m, 0, n));
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
});
});
});
});
}
};
} // namespace ck
#endif
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
set(GEMV_SPLITK_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMV_SPLITK_INSTANCES device_gemv_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMV_SPLITK_INSTANCES device_gemv_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
endif()
add_instance_library(device_gemv_splitk_instance ${GEMV_SPLITK_INSTANCES})
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
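// How to read an instance (a worked example against the legend above, taking the first entry):
// BlockSize=64, MPerBlock=1, NPerBlock=128, K0PerBlock=1, K1=2, M1PerThread=1, N1PerThread=2,
// KPerThread=1, i.e. a 64-thread workgroup computes a 1 x 128 tile of C, with each thread
// producing N1=2 output elements and reading K1=2 contiguous elements per vector access.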
//N1=2
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//N1=4
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//N1=8
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
>;
void add_device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceTsmm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 4, K1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//N1=2
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//N1=4
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//N1=8
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
>;
void add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceTsmm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
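
The registration function above only appends the `deviceTsmmDl` configurations to a caller-supplied vector; the matching ckprofiler entry is expected to walk that vector and pick the fastest candidate. Below is a minimal, hedged sketch of how a client could enumerate the registered GEMV split-K instances. The forward declaration simply mirrors the definition above (a real build would take it from an instance header whose path is not shown in this commit), and `GetTypeString()` is assumed to be inherited from `ck::BaseOperator`, as for other CK device operators.

```cpp
// Hedged sketch: enumerate the GEMV split-K candidates registered above.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"

using F16         = ck::half_t;
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using GemvDeviceOp = ck::tensor_operation::device::
    DeviceTsmm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

namespace ck::tensor_operation::device::instance {
// Mirrors the definition above; a real build would get this from an instance header.
void add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<GemvDeviceOp>>&);
} // namespace ck::tensor_operation::device::instance

int main()
{
    std::vector<std::unique_ptr<GemvDeviceOp>> instances;
    ck::tensor_operation::device::instance::
        add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(instances);

    // Each entry is one deviceTsmmDl tuning configuration from the tuple above.
    // GetTypeString() is assumed to come from ck::BaseOperator.
    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n';

    std::cout << instances.size() << " GEMV split-K instances registered\n";
    return 0;
}
```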
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
set(TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
endif()
add_instance_library(device_tall_and_skinny_gemm_splitk_instance ${TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES})
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//M1 is fixed at 16 for all tall-and-skinny instances below
//N1=2
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//N1=4
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//N1=8
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
>;
void add_device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceTsmm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances, device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
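
Listing candidates is only half of what a profiler entry needs; it also has to construct arguments and time every instance on the actual problem. The sketch below shows such a loop over the row-major-B (`mk_kn_mn`) instances registered above. It is only a sketch under stated assumptions: the `MakeArgumentPointer`, `MakeInvokerPointer`, `IsSupportedArgument` and `GetTypeString` members, the argument order (pointers, sizes, strides, element-wise operations, split-K batch count), and the helper name `time_tsmm_splitk_candidates` are patterned on other CK split-K device operators and are not confirmed by the files in this commit.

```cpp
// Hedged sketch of a ckprofiler-style timing loop; the DeviceTsmm member functions and
// their argument order are assumptions modelled on other CK split-K device operators.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"

using F16         = ck::half_t;
using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using TsmmDeviceOp = ck::tensor_operation::device::
    DeviceTsmm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

namespace ck::tensor_operation::device::instance {
// Mirrors the definition above; a real build would get this from an instance header.
void add_device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<TsmmDeviceOp>>&);
} // namespace ck::tensor_operation::device::instance

// p_a, p_b and p_c are assumed to point to device buffers already sized for the
// fp16 A[M, K], B[K, N] and C[M, N] tensors.
inline void time_tsmm_splitk_candidates(const void* p_a, const void* p_b, void* p_c,
                                        ck::index_t M, ck::index_t N, ck::index_t K,
                                        ck::index_t stride_A, ck::index_t stride_B,
                                        ck::index_t stride_C, ck::index_t k_batch)
{
    std::vector<std::unique_ptr<TsmmDeviceOp>> instances;
    ck::tensor_operation::device::instance::
        add_device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances(instances);

    for(auto& op : instances)
    {
        // Assumed argument layout: pointers, sizes, strides, element-wise ops, k_batch.
        auto argument = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                                stride_A, stride_B, stride_C,
                                                PassThrough{}, PassThrough{}, PassThrough{},
                                                k_batch);
        if(!op->IsSupportedArgument(argument.get()))
            continue; // skip configurations that cannot cover this problem shape

        auto invoker       = op->MakeInvokerPointer();
        const float ave_ms = invoker->Run(argument.get(), StreamConfig{nullptr, true});
        std::cout << op->GetTypeString() << ": " << ave_ms << " ms" << std::endl;
    }
}
```

A caller would keep whichever instance reports the smallest time for its particular M, N, K and split-K batch count.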