Commit 43adf1fa authored by Harisankar Sadasivan

clang format

parent ab3d3b4a
@@ -11,7 +11,7 @@
#ifndef KERNARG_PRELOAD
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig& stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
@@ -19,7 +19,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -49,7 +49,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        for(int i = 0; i < nrepeat; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
            hip_check_error(hipGetLastError());
@@ -81,7 +81,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
#else
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig& stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
@@ -92,7 +92,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
    // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
    // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -109,9 +109,9 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
        //
        // warm up
        const int nrepeat = 1000;
        for(auto i = 0; i < nrepeat; i++)
            hipLaunchKernelGGL(
                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
        hip_check_error(hipGetLastError());
#if DEBUG_LOG
@@ -127,9 +127,9 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        for(int i = 0; i < nrepeat; ++i)
            hipLaunchKernelGGL(
                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
        // hip_check_error(hipGetLastError());

        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
@@ -140,8 +140,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
    }
    else
    {
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
        hip_check_error(hipGetLastError());

        return 0;
@@ -155,7 +154,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
}
#endif

template <typename... Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                             PreProcessFunc preprocess,
                                             F kernel,
                                             dim3 grid_dim,
@@ -164,7 +163,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
                                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -195,7 +194,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        for(int i = 0; i < nrepeat; ++i)
        {
            preprocess();
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
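The timing path above follows the usual hipEvent pattern: warm up, record a start event, launch the kernel nrepeat times, record a stop event, and divide the elapsed time by nrepeat. A minimal standalone sketch of that pattern, with a made-up scale kernel and problem size and without the CK error-checking wrappers, might look like this:

// Minimal sketch of the hipEvent timing pattern used by launch_and_time_kernel.
// The scale kernel, sizes, and nrepeat value are illustrative only.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void scale(float* p, float alpha, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] *= alpha;
}

int main()
{
    const int n = 1 << 20;
    float* d_p  = nullptr;
    (void)hipMalloc(&d_p, n * sizeof(float));

    const int nrepeat = 1000;
    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);

    // warm up, then make sure the device is idle before timing
    scale<<<dim3((n + 255) / 256), dim3(256), 0, nullptr>>>(d_p, 2.f, n);
    (void)hipDeviceSynchronize();

    (void)hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
        scale<<<dim3((n + 255) / 256), dim3(256), 0, nullptr>>>(d_p, 2.f, n);
    (void)hipEventRecord(stop, nullptr);
    (void)hipEventSynchronize(stop);

    float total_ms = 0.f;
    (void)hipEventElapsedTime(&total_ms, start, stop);
    printf("average kernel time: %f ms\n", total_ms / nrepeat);

    (void)hipFree(d_p);
    return 0;
}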
@@ -16,364 +16,465 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <
    typename ADataType,
    typename BDataType,
    typename CDataType,
    typename AccDataType,
    typename ALayout,
    typename BLayout,
    typename CLayout,
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CElementwiseOperation,
    GemmSpecialization GemmSpec,
    index_t BlockSize,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t K0PerBlock,
    index_t K1,
    index_t MPerThread,
    index_t NPerThread,
    index_t KPerThread,
    typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
    typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
    typename ABlockTransferSrcVectorTensorContiguousDimOrder,
    typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
    typename BThreadTransferSrcDstAccessOrder,
    index_t BThreadTransferSrcVectorDim,
    index_t BThreadTransferSrcScalarPerVector,
    typename CThreadTransferSrcDstAccessOrder,
    index_t CThreadTransferSrcDstVectorDim,
    index_t CThreadTransferDstScalarPerVector,
    enable_if_t<
        is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
            is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
            is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
        bool> = false>
struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                        BLayout,
                                        CLayout,
                                        ADataType,
                                        BDataType,
                                        CDataType,
                                        AElementwiseOperation,
                                        BElementwiseOperation,
                                        CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    // GridwiseTsmm
    using GridwiseTsmm =
        GridwiseTsmmDl_km_kn_mn<BlockSize,
                                ADataType,
                                AccDataType,
                                CDataType,
                                ALayout,
                                BLayout,
                                CLayout,
                                GemmSpec,
                                MPerBlock,
                                NPerBlock,
                                K0PerBlock,
                                K1,
                                MPerThread,
                                NPerThread,
                                KPerThread,
                                ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferThreadClusterArrangeOrder,
                                ABlockTransferSrcAccessOrder,
                                ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferSrcVectorTensorContiguousDimOrder,
                                ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
                                BThreadTransferSrcDstAccessOrder,
                                BThreadTransferSrcVectorDim,
                                BThreadTransferSrcScalarPerVector,
                                CThreadTransferSrcDstAccessOrder,
                                CThreadTransferSrcDstVectorDim,
                                CThreadTransferDstScalarPerVector>;

    using DefaultBlock2CTileMap = typename GridwiseTsmm::DefaultBlock2CTileMap;
    using Argument              = typename GridwiseTsmm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
        {
            const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
            const auto b2c_map      = DefaultBlock2CTileMap{};

            const auto K0 = karg.K0;

            const bool has_main_k_block_loop = GridwiseTsmm::CalculateHasMainKBlockLoop(K0);
            const bool has_double_tail_k_block_loop =
                GridwiseTsmm::CalculateHasDoubleTailKBlockLoop(K0);

            float ave_time = 0;

            if(karg.k_batch > 1)
                hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));

            if(has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                if(karg.k_batch == 1)
                {
                    const auto kernel =
                        kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType, BLayout,
                                            InMemoryDataOperationEnum::Set,
                                            true, true, DefaultBlock2CTileMap>; // //
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                }
                else
                {
                    const auto kernel =
                        kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType, BLayout,
                                            InMemoryDataOperationEnum::AtomicAdd,
                                            true, true, DefaultBlock2CTileMap>; // //
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                }
            }
            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
            {
                if(karg.k_batch == 1)
                {
                    const auto kernel =
                        kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType, BLayout,
                                            InMemoryDataOperationEnum::Set,
                                            true, false, DefaultBlock2CTileMap>; // //
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                }
                else
                {
                    const auto kernel =
                        kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType, BLayout,
                                            InMemoryDataOperationEnum::AtomicAdd,
                                            true, false, DefaultBlock2CTileMap>; // //
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                }
            }
            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                if(karg.k_batch == 1)
                {
                    const auto kernel =
                        kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType, BLayout,
                                            InMemoryDataOperationEnum::Set,
                                            false, true, DefaultBlock2CTileMap>; // //
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                }
                else
                {
                    const auto kernel =
                        kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType, BLayout,
                                            InMemoryDataOperationEnum::AtomicAdd,
                                            false, true, DefaultBlock2CTileMap>; // //
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                }
            }
            else
            {
                if(karg.k_batch == 1)
                {
                    const auto kernel =
                        kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType, BLayout,
                                            InMemoryDataOperationEnum::Set,
                                            false, false, DefaultBlock2CTileMap>; // //
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                }
                else
                {
                    const auto kernel =
                        kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType, BLayout,
                                            InMemoryDataOperationEnum::AtomicAdd,
                                            false, false, DefaultBlock2CTileMap>; // //
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // //
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
           ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
           ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
           ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx940" ||
           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
        {
            return GridwiseTsmm::CheckValidity(arg);
        }
        else
        {
            return false;
        }
    }

    // //
    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation,
                             index_t KBatch) // //
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        GridwiseTsmm::CalculateMPadded(M),
                        GridwiseTsmm::CalculateNPadded(N),
                        // GridwiseTsmm::CalculateKPadded(K, KBatch),
                        GridwiseTsmm::CalculateK0(K, KBatch),
                        KBatch}; // //
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation,
                                                      ck::index_t KBatch = 1) override // //
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          GridwiseTsmm::CalculateMPadded(M),
                                          GridwiseTsmm::CalculateNPadded(N),
                                          // GridwiseTsmm::CalculateKPadded(K, KBatch),
                                          GridwiseTsmm::CalculateK0(K, KBatch),
                                          KBatch); // //
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "deviceTsmmDl"
            << "<"
            << BlockSize << ", "
@@ -385,12 +486,12 @@ namespace ck
            << NPerThread << ", "
            << KPerThread
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
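Invoker::Run above picks one of eight kernel_tsmm_dl_v1r3 instantiations: the shape of the K0 loop selects the HasMainKBlockLoop / HasDoubleTailKBlockLoop pair, and k_batch selects Set versus AtomicAdd for C (with hipMemset zeroing C first when k_batch > 1 so the K-slices can accumulate). A small host-side sketch of that selection, using plain ints and hypothetical names instead of the real template dispatch, and mirroring the CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop formulas defined in the gridwise header below:

// Toy host-side illustration of the kernel-variant selection in Invoker::Run.
// KernelVariant and select_variant are hypothetical stand-ins.
#include <cstdio>

struct KernelVariant
{
    bool main_k_loop;        // HasMainKBlockLoop
    bool double_tail_k_loop; // HasDoubleTailKBlockLoop
    bool atomic_add;         // InMemoryDataOperationEnum::AtomicAdd vs. Set
};

KernelVariant select_variant(int K0, int K0PerBlock, int k_batch)
{
    // mirrors CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop
    const bool has_main     = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
    const bool has_dbl_tail = (K0 / K0PerBlock) % 2 == 0;
    // more than one K-slice: C is zeroed once up front and written with atomic adds
    return KernelVariant{has_main, has_dbl_tail, k_batch > 1};
}

int main()
{
    const KernelVariant v = select_variant(/*K0=*/32, /*K0PerBlock=*/4, /*k_batch=*/2);
    printf("main loop: %d, double tail: %d, atomic add: %d\n",
           v.main_k_loop, v.double_tail_k_loop, v.atomic_add);
    return 0;
}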
@@ -16,769 +16,809 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseTsmm,
          typename FloatAB,
          typename FloatC,
          typename BLayout,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop,
          typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_tsmm_dl_v1r3(const FloatAB* p_a_grid,
                            const FloatAB* p_b_grid,
                            FloatC* p_c_grid,
                            index_t M,
                            index_t N,
                            index_t K,
                            index_t K0,
                            index_t k_batch,
                            index_t MPadded,
                            index_t NPadded,
                            const Block2CTileMap block_2_ctile_map) //: in __global__ functions,
                                                                    // struct is better for reduced
                                                                    // load overhead
{
    // strides depend on B's layout
    if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
    {
        GridwiseTsmm::template Run<HasMainKBlockLoop,
                                   HasDoubleTailKBlockLoop,
                                   GridwiseTsmm,
                                   CGlobalMemoryDataOperation>(
            p_a_grid, p_b_grid, p_c_grid, M, N, K, K0, k_batch, K, N, N, MPadded, NPadded,
            block_2_ctile_map);
    }
    else
    {
        GridwiseTsmm::template Run<HasMainKBlockLoop,
                                   HasDoubleTailKBlockLoop,
                                   GridwiseTsmm,
                                   CGlobalMemoryDataOperation>(
            p_a_grid, p_b_grid, p_c_grid, M, N, K, K0, k_batch, K, K, N, MPadded, NPadded,
            block_2_ctile_map);
    }
}
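kernel_tsmm_dl_v1r3 above dispatches at compile time on BLayout: with a row-major B it forwards (K, N, N) as (StrideA, StrideB, StrideC) to GridwiseTsmm::Run, otherwise (K, K, N). A toy sketch of that if constexpr dispatch, with simplified stand-in layout tags instead of the CK tensor_layout types:

// Toy sketch of the compile-time B-layout dispatch in kernel_tsmm_dl_v1r3.
// RowMajor/ColumnMajor and pick_strides are simplified stand-ins.
#include <cstdio>
#include <tuple>
#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};

template <typename BLayout>
std::tuple<int, int, int> pick_strides(int K, int N)
{
    if constexpr(std::is_same_v<BLayout, RowMajor>)
        return {K, N, N}; // StrideA, StrideB, StrideC when B is row-major
    else
        return {K, K, N}; // StrideA, StrideB, StrideC when B is column-major
}

int main()
{
    auto [sa, sb, sc] = pick_strides<RowMajor>(64, 128);
    printf("row-major B: StrideA=%d StrideB=%d StrideC=%d\n", sa, sb, sc);
    return 0;
}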
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
typename ALayout,
typename BLayout,
typename CLayout,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseTsmmDl_km_kn_mn
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
// Argument // Argument
struct Argument : public tensor_operation::device::BaseArgument // struct Argument : public tensor_operation::device::BaseArgument //
{
Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
// index_t KPadded_,
index_t K0_,
index_t k_batch_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded(MPadded_),
NPadded(NPadded_),
// KPadded(KPadded_),
K0(K0_),
k_batch(k_batch_)
{ {
Argument(const FloatAB *p_a_grid_, }
const FloatAB *p_b_grid_,
FloatC *p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
// index_t KPadded_,
index_t K0_,
index_t k_batch_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded(MPadded_),
NPadded(NPadded_),
// KPadded(KPadded_),
K0(K0_),
k_batch(k_batch_)
{
}
// private: // private:
const FloatAB *p_a_grid; const FloatAB* p_a_grid;
const FloatAB *p_b_grid; const FloatAB* p_b_grid;
FloatC *p_c_grid; FloatC* p_c_grid;
index_t M, N, K; index_t M, N, K;
index_t StrideA, StrideB, StrideC; index_t StrideA, StrideB, StrideC;
//: //:
index_t MPadded; index_t MPadded;
index_t NPadded; index_t NPadded;
// index_t KPadded; // index_t KPadded;
index_t K0; index_t K0;
index_t k_batch; index_t k_batch;
}; };
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
// A matrix in LDS memory, dst of blockwise copy {
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( // TODO: change this. I think it needs multi-dimensional alignment
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); constexpr auto max_lds_align = K1;
// TODO: check alignment // TODO: check alignment
// LDS allocation for A and B: be careful of alignment // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_aligned_space_size = constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
return 2 * (a_block_aligned_space_size) * sizeof(FloatAB); // TODO: check alignment
} // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
__host__ __device__ static constexpr index_t return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
CalculateGridSize(index_t M, index_t N, index_t k_batch) // }
{
const index_t grid_size = math::integer_divide_ceil(N, NPerBlock) *
math::integer_divide_ceil(M, MPerBlock) * k_batch;
return grid_size; __host__ __device__ static constexpr index_t
} CalculateGridSize(index_t M, index_t N, index_t k_batch) //
{
const index_t grid_size = math::integer_divide_ceil(N, NPerBlock) *
math::integer_divide_ceil(M, MPerBlock) * k_batch;
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) return grid_size;
{ }
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop; __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
} {
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) return has_main_k_block_loop;
}
{ __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop; {
} const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
__host__ __device__ static auto CalculateMPadded(index_t M) return has_double_tail_k_block_loop;
{ }
return math::integer_least_multiple(M, MPerBlock);
}
__host__ __device__ static auto CalculateNPadded(index_t N) __host__ __device__ static auto CalculateMPadded(index_t M)
{ {
return math::integer_least_multiple(N, NPerBlock); return math::integer_least_multiple(M, MPerBlock);
} }
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) __host__ __device__ static auto CalculateNPadded(index_t N)
{ {
// k_batch * k0 * k0_per_block * k1 return math::integer_least_multiple(N, NPerBlock);
auto K_t = K_Batch * K0PerBlock * K1; }
return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
{ {
auto K0 = CalculateK0(K, K_Batch); // k_batch * k0 * k0_per_block * k1
return K_Batch * K0 * K1; auto K_t = K_Batch * K0PerBlock * K1;
} return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
// M, K -> KBatch, K0, M, K1: M -> MPad, K->KBatch, K0, K1 // M, K -> KBatch, K0, M, K1: M -> MPad, K->KBatch, K0, K1
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1( __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(
index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0) index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
{ {
const auto a_grid_desc_m_k = [&]() const auto a_grid_desc_m_k = [&]() {
{ if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
if constexpr (is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{ {
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(
make_tuple(KBatch, K0, K1Number)), // unmerge is split 1D to 3D
make_right_pad_transform(M, MPad - M)), //
make_tuple(Sequence<1>{}, Sequence<0>{}), // mapped to input M & K; sequence 0 is M;
// 1 is K; make unmerge is working on K;
make_tuple(Sequence<0, 1, 3>{}, // input is M,K; output we want is Kbatch, K0 and K1
// -> 0, 1, 3; output is transformed from 2D to 4D
Sequence<2>{})); // 2->M
} }
else else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{ {
return transform_tensor_descriptor( return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
} }
} }();
__host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1( if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
{ {
const auto b_grid_desc_k_n = [&]() return transform_tensor_descriptor(
{ a_grid_desc_m_k,
if constexpr (is_same<tensor_layout::gemm::RowMajor, BLayout>::value) make_tuple(make_unmerge_transform(
{ make_tuple(KBatch, K0, K1Number)), // unmerge is split 1D to 3D
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); make_right_pad_transform(M, MPad - M)), //
} make_tuple(Sequence<1>{}, Sequence<0>{}), // mapped to input M & K; sequence 0 is M;
else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) // 1 is K; make unmerge is working on K;
{ make_tuple(Sequence<0, 1, 3>{}, // input is M,K; output we want is Kbatch, K0 and K1
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); // -> 0, 1, 3; output is transformed from 2D to 4D
} Sequence<2>{})); // 2->M
}();
if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
} }
else
__host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{ {
const auto c_grid_desc_m_n = [&]() return transform_tensor_descriptor(
{ a_grid_desc_m_k,
if constexpr (is_same<tensor_layout::gemm::RowMajor, CLayout>::value) make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
{ make_pass_through_transform(M)),
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); make_tuple(Sequence<1>{}, Sequence<0>{}),
} make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value) }
{ }
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
} __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(
}(); index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
{
if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
} }();
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
using AGridDesc_Kbatch_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1, 1));
using BGridDesc_Kbatch_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
__host__ __device__ static constexpr bool CheckValidity(const Argument &karg) if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{ {
// const auto MPadded = CalculateMPadded(karg.M); return transform_tensor_descriptor(
// const auto NPadded = CalculateNPadded(karg.N); b_grid_desc_k_n,
const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1( make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0); make_right_pad_transform(N, NPad - N)),
const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1( make_tuple(Sequence<0>{}, Sequence<1>{}),
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto KBatch_b = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto N_ = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
return (M_ % MPerBlock == 0 && N_ % NPerBlock == 0 && K0_ % K0PerBlock == 0 &&
M_ == c_grid_desc_m_n.GetLength(I0) && N_ == c_grid_desc_m_n.GetLength(I1) &&
a_grid_desc_kbatch_k0_m_k1.GetLength(I3) ==
b_grid_desc_kbatch_k0_n_k1.GetLength(I3) &&
karg.k_batch >= 1 && KBatch_a == karg.k_batch && KBatch_b == karg.k_batch);
} }
else
// KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
__host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
const AGridDesc_Kbatch_K0_M_K1 &a_grid_desc_kbatch_k0_m_k1)
{ {
const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0); return transform_tensor_descriptor(
const auto K0 = a_grid_desc_kbatch_k0_m_k1.GetLength(I1); b_grid_desc_k_n,
const auto M = a_grid_desc_kbatch_k0_m_k1.GetLength(I2); make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)),
const auto M1 = Number<MPerBlock>{}; make_tuple(Sequence<0>{}, Sequence<1>{}),
const auto M0 = M / M1; make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
const auto a_grid_desc_kbatch_k0_m0_m1_k1 = transform_tensor_descriptor(
a_grid_desc_kbatch_k0_m_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), // IP
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); // OP
return a_grid_desc_kbatch_k0_m0_m1_k1;
} }
}
__host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1( __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
const BGridDesc_Kbatch_K0_N_K1 &b_grid_desc_kbatch_k0_n_k1) {
{ const auto c_grid_desc_m_n = [&]() {
const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0); if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
const auto K0 = b_grid_desc_kbatch_k0_n_k1.GetLength(I1); {
const auto N = b_grid_desc_kbatch_k0_n_k1.GetLength(I2); return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
const auto N1 = Number<NPerBlock>{}; else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
const auto N0 = N / N1; {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
const auto b_grid_desc_kbatch_k0_n0_n1_k1 = transform_tensor_descriptor( }
b_grid_desc_kbatch_k0_n_k1, }();
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_kbatch_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto N = c_grid_desc_m_n.GetLength(I1); const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 = Number<MPerThread>{};
constexpr auto N11 = Number<NPerThread>{};
constexpr auto M10 = M1 / M11; return transform_tensor_descriptor(
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
} }
else
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
{ {
//: 3d ksplit for C
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>(); return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>; // }
using AGridDesc_K0_M0_M1_K1 =
decltype(MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(AGridDesc_Kbatch_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(BGridDesc_Kbatch_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); //
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap()); //
template <bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop,
typename GridwiseTsmm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static void Run(const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid, index_t M, index_t N, index_t K,
index_t K0, index_t k_batch, index_t StrideA, index_t StrideB, index_t StrideC, index_t MPadded, index_t NPadded, const Block2CTileMap &block_2_ctile_map)
{
constexpr index_t shared_block_size = __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); {
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
__shared__ FloatAB p_shared_block[shared_block_size]; const index_t KPad = KBatch * K0 * K1;
return KPad;
const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1( }
M, MPadded, K, StrideA, k_batch, K0); //
const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
K, NPadded, N, StrideB, k_batch, K0); //
const auto c_grid_desc_m_n =
GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
const auto b_grid_desc_kbatch_k0_n0_n1_k1 =
GridwiseTsmm::MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1); //
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseTsmm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_kbatch_k0_n0_n1_k1.GetElementSpaceSize());
ignore = b_global_buf;
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
get_block_1d_id(), N, k_batch);
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);
if (!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment using AGridDesc_Kbatch_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1, 1));
constexpr auto max_lds_align = K1; using BGridDesc_Kbatch_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
constexpr auto a_block_desc_copy_kbatch_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>, //: 5 dimensions; kbatch for each
// dimension is 1
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, // 0, 1, 2, 3, 4
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_kbatch_k0_m0_m1_k1)>, // Global tensor desc
decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1), // block tensor desc
ABlockTransferSrcAccessOrder, // 5-dim
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_kbatch_k0_m0_m1_k1, // for src desc
make_multi_index(kbatch_id, 0, im0, 0, 0), //: calculate start index of K
a_block_desc_copy_kbatch_k0_m0_m1_k1, // for dst desc
make_multi_index(0, 0, 0, 0, 0));
static constexpr auto b_thread_desc_copy_kbatch_k0_n0_n1_k1 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<K0PerBlock>{},
I1,
Number<NPerThread>{},
Number<K1>{})); //: this descriptor is used only for copy
static constexpr auto b_thread_desc_copy_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_kbatch_k0_n0_n1_k1)>,
decltype(b_thread_desc_copy_kbatch_k0_n0_n1_k1), //
Sequence<1, K0PerBlock, 1, NPerThread, K1.value>,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
// const auto MPadded = CalculateMPadded(karg.M);
// const auto NPadded = CalculateNPadded(karg.N);
const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto KBatch_b = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto N_ = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
return (M_ % MPerBlock == 0 && N_ % NPerBlock == 0 && K0_ % K0PerBlock == 0 &&
M_ == c_grid_desc_m_n.GetLength(I0) && N_ == c_grid_desc_m_n.GetLength(I1) &&
a_grid_desc_kbatch_k0_m_k1.GetLength(I3) ==
b_grid_desc_kbatch_k0_n_k1.GetLength(I3) &&
karg.k_batch >= 1 && KBatch_a == karg.k_batch && KBatch_b == karg.k_batch);
}
// KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
__host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1)
{
const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto K0 = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_kbatch_k0_m0_m1_k1 = transform_tensor_descriptor(
a_grid_desc_kbatch_k0_m_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), // IP
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); // OP
return a_grid_desc_kbatch_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1)
{
const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0 = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
const auto N = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_kbatch_k0_n0_n1_k1 = transform_tensor_descriptor(
b_grid_desc_kbatch_k0_n_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_kbatch_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 = Number<MPerThread>{};
constexpr auto N11 = Number<NPerThread>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
{
//: 3d ksplit for C
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
}
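// Descriptor and tile-map types derived from the Make* helpers above.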
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
using AGridDesc_K0_M0_M1_K1 =
decltype(MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(AGridDesc_Kbatch_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(BGridDesc_Kbatch_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap());
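// Run: one workgroup computes an MPerBlock x NPerBlock tile of C for a single k_batch split.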
template <bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop,
typename GridwiseTsmm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t K0,
index_t k_batch,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t MPadded,
index_t NPadded,
const Block2CTileMap& block_2_ctile_map)
{
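// Static LDS allocation sized by GridwiseTsmm::GetSharedMemoryNumberOfByte(); it backs both halves of the A double buffer set up below.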
constexpr index_t shared_block_size =
GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
M, MPadded, K, StrideA, k_batch, K0);
const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
K, NPadded, N, StrideB, k_batch, K0);
const auto c_grid_desc_m_n = GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1);
const auto b_grid_desc_kbatch_k0_n0_n1_k1 =
GridwiseTsmm::MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1);
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseTsmm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
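// Wrap the raw global pointers in dynamic buffers sized by the transformed grid descriptors.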
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_kbatch_k0_n0_n1_k1.GetElementSpaceSize());
ignore = b_global_buf;
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
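// Map the 1-D block id to a (m0, n0, kbatch) tile index; blocks that fall outside the C grid exit early below.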
const auto c_m0_n0_block_cluster_idx =
block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(get_block_1d_id(), N, k_batch);
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_copy_kbatch_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>, //: 5 dimensions; kbatch for each
// dimension is 1
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, // 0, 1, 2, 3, 4
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_kbatch_k0_m0_m1_k1)>, // Global tensor desc
decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1), // block tensor desc
ABlockTransferSrcAccessOrder, // 5-dim
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_kbatch_k0_m0_m1_k1, // for src desc
make_multi_index(kbatch_id, 0, im0, 0, 0), //: calculate start index of K
a_block_desc_copy_kbatch_k0_m0_m1_k1, // for dst desc
make_multi_index(0, 0, 0, 0, 0));
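// B path: no LDS staging. Each thread copies its own K0PerBlock x NPerThread x K1 slice of B directly from global memory into registers.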
static constexpr auto b_thread_desc_copy_kbatch_k0_n0_n1_k1 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<K0PerBlock>{},
I1,
Number<NPerThread>{},
Number<K1>{})); //: this descriptor is used only for copy
static constexpr auto b_thread_desc_copy_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_kbatch_k0_n0_n1_k1)>,
decltype(b_thread_desc_copy_kbatch_k0_n0_n1_k1), //
Sequence<1, K0PerBlock, 1, NPerThread, K1.value>,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
1,
false,
true>(b_grid_desc_kbatch_k0_n0_n1_k1,
make_multi_index(kbatch_id, 0, in0, get_thread_local_1d_id() * NPerThread, 0));
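// Register-side descriptor for the per-thread B tile consumed by the blockwise GEMM.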
static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
const auto blockwise_tsmm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_thread_desc),
MPerThread,
NPerBlock,
KPerThread>{};
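// Per-thread C tile lengths (BM0, BM1, BN0, BN1) reported by the blockwise GEMM; they size the accumulator below.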
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_tsmm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_k0_n_k1_thread_desc.GetElementSpaceSize());
auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_k0_n_k1_thread_desc.GetElementSpaceSize());
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
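// Two LDS views, offset by one aligned A tile, form the even/odd (ping-pong) double buffer.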
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1,
a_global_buf); // a_global_buf -> reg_tmp_buf
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1,
a_block_even_buf); // reg_tmp_buf->a_block_even_buf
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf);
}
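// Main K loop: each do-while iteration consumes 2 * K0PerBlock of K0, overlapping the next loads with the GEMM on the current buffers.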
if constexpr(HasMainKBlockLoop)
{
// const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
b_thread_slice_copy_step);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
b_global_buf,
b_thread_desc_copy_k0_n0_n1_k1,
make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
}
// output: register to global memory
{
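// Each thread writes its accumulator tile to C through a threadwise transfer. With k_batch > 1, CGlobalMemoryDataOperation is presumably an atomic add so partial results from different K splits combine correctly (assumption; the operation is supplied by the caller).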
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_tsmm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
};
} // namespace ck