Commit f88c2f86 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan

Kernarg load latency optimization for MI300

parent c2784145
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
        add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
        add_dependencies(example_gemv_splitk
                         example_gemv_splitk_fp16)
        set_source_files_properties(gemv_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
        set(target 1)
    endif()
endforeach()
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
        add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
        add_dependencies(example_tall_and_skinny_gemm_splitk
                         example_tall_and_skinny_gemm_splitk_fp16)
        set_source_files_properties(tall_and_skinny_gemm_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
        set(target 1)
    endif()
endforeach()
\ No newline at end of file
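The new COMPILE_OPTIONS line does two things: -DKERNARG_PRELOAD switches kernel_launch.hpp (next file) to its flattened-argument timing path, and -mllvm --amdgpu-kernarg-preload-count=16 asks the AMDGPU backend to preload the leading kernel-argument dwords into SGPRs. A minimal C++ sketch of the macro's effect, assuming nothing beyond the define added above:

// Sketch only: the KERNARG_PRELOAD macro defined in these CMakeLists changes is
// what selects the flattened-argument launch path in ck/host_utility/kernel_launch.hpp.
#ifdef KERNARG_PRELOAD
// preload build: kernel arguments are passed as individual scalars/pointers so
// the compiler can place the leading kernarg dwords in SGPRs
#else
// default build: the struct-argument launch path is used
#endif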
@@ -9,8 +9,9 @@
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"

#ifndef KERNARG_PRELOAD
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig &stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
@@ -18,7 +19,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
                             Args... args)
{
#if CK_TIME_KERNEL
    if (stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -48,7 +49,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        for (int i = 0; i < nrepeat; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
            hip_check_error(hipGetLastError());
@@ -78,8 +79,83 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#endif
}
#else
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig &stream_config,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
// Args* args1;
// hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
// hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
#if CK_TIME_KERNEL
if (stream_config.time_kernel_)
{
#if DEBUG_LOG
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up 1 time\n");
#endif
//
// warm up
const int nrepeat = 1000;
for (auto i = 0; i < nrepeat; i++)
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_,
args...);
hip_check_error(hipGetLastError());
#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
hipEvent_t start, stop;
float total_time = 0;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for (int i = 0; i < nrepeat; ++i)
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_,
args...);
// hip_check_error(hipGetLastError());
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
return total_time / nrepeat;
}
else
{
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(
args...);
hip_check_error(hipGetLastError());
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0;
#endif
}
#endif
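For reference, a minimal usage sketch of the KERNARG_PRELOAD timing path above. The kernel, buffers, and launch geometry here are hypothetical and not part of this commit; the point is that the arguments are plain pointers and scalars rather than one aggregate, so the leading kernarg dwords are eligible for SGPR preload:

#include <hip/hip_runtime.h>
#include "ck/stream_config.hpp"

// hypothetical example kernel: scales a vector
__global__ void my_kernel(const float* a, float* c, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        c[i] = 2.0f * a[i];
}

// times my_kernel through the KERNARG_PRELOAD variant of launch_and_time_kernel
float time_my_kernel(const StreamConfig& stream_config, const float* a_dev, float* c_dev, int n)
{
    const dim3 grid((n + 255) / 256);
    const dim3 block(256);
    return launch_and_time_kernel(stream_config, my_kernel, grid, block, 0 /*lds_byte*/,
                                  a_dev, c_dev, n);
}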
template <typename... Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
                                             PreProcessFunc preprocess,
                                             F kernel,
                                             dim3 grid_dim,
@@ -88,7 +164,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                             Args... args)
{
#if CK_TIME_KERNEL
    if (stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -119,7 +195,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        for (int i = 0; i < nrepeat; ++i)
        {
            preprocess();
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
...
@@ -16,345 +16,364 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck
{
namespace tensor_operation
{
namespace device
{
template <
    typename ADataType,
    typename BDataType,
    typename CDataType,
    typename AccDataType,
    typename ALayout,
    typename BLayout,
    typename CLayout,
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CElementwiseOperation,
    GemmSpecialization GemmSpec,
    index_t BlockSize,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t K0PerBlock,
    index_t K1,
    index_t MPerThread,
    index_t NPerThread,
    index_t KPerThread,
    typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
    typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
    typename ABlockTransferSrcVectorTensorContiguousDimOrder,
    typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
    typename BThreadTransferSrcDstAccessOrder,
    index_t BThreadTransferSrcVectorDim,
    index_t BThreadTransferSrcScalarPerVector,
    typename CThreadTransferSrcDstAccessOrder,
    index_t CThreadTransferSrcDstVectorDim,
    index_t CThreadTransferDstScalarPerVector,
    enable_if_t<
        is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
            is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
            is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
        bool> = false>
struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                        BLayout,
                                        CLayout,
                                        ADataType,
                                        BDataType,
                                        CDataType,
                                        AElementwiseOperation,
                                        BElementwiseOperation,
                                        CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    // GridwiseTsmm
    using GridwiseTsmm =
        GridwiseTsmmDl_km_kn_mn<BlockSize,
                                ADataType,
                                AccDataType,
                                CDataType,
                                ALayout,
                                BLayout,
                                CLayout,
                                GemmSpec,
                                MPerBlock,
                                NPerBlock,
                                K0PerBlock,
                                K1,
                                MPerThread,
                                NPerThread,
                                KPerThread,
                                ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferThreadClusterArrangeOrder,
                                ABlockTransferSrcAccessOrder,
                                ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferSrcVectorTensorContiguousDimOrder,
                                ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
                                BThreadTransferSrcDstAccessOrder,
                                BThreadTransferSrcVectorDim,
                                BThreadTransferSrcScalarPerVector,
                                CThreadTransferSrcDstAccessOrder,
                                CThreadTransferSrcDstVectorDim,
                                CThreadTransferDstScalarPerVector>;

    using DefaultBlock2CTileMap = typename GridwiseTsmm::DefaultBlock2CTileMap;
    using Argument              = typename GridwiseTsmm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument &karg, const StreamConfig &stream_config = StreamConfig{})
        {
            const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
            const auto b2c_map = DefaultBlock2CTileMap{};
            const auto K0      = karg.K0;

            const bool has_main_k_block_loop = GridwiseTsmm::CalculateHasMainKBlockLoop(K0);
            const bool has_double_tail_k_block_loop =
                GridwiseTsmm::CalculateHasDoubleTailKBlockLoop(K0);

            float ave_time = 0;

            if (karg.k_batch > 1)
                hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));

            if (has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                if (karg.k_batch == 1)
                {
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, BLayout,
true, InMemoryDataOperationEnum::Set,
true, true,
DefaultBlock2CTileMap>; // // true,
ave_time = launch_and_time_kernel( DefaultBlock2CTileMap>; // //
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg); ave_time = launch_and_time_kernel(
} stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
else (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
{ }
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, else
ADataType, {
CDataType, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
InMemoryDataOperationEnum::AtomicAdd, ADataType,
true, CDataType,
true, BLayout,
DefaultBlock2CTileMap>; // // InMemoryDataOperationEnum::AtomicAdd,
ave_time = launch_and_time_kernel( true,
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg); true,
} DefaultBlock2CTileMap>; // //
} ave_time = launch_and_time_kernel(
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
{ (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
}
}
else if (has_main_k_block_loop && !has_double_tail_k_block_loop)
{
if(karg.k_batch == 1) if (karg.k_batch == 1)
{ {
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, BLayout,
true, InMemoryDataOperationEnum::Set,
false, true,
DefaultBlock2CTileMap>; // // false,
ave_time = launch_and_time_kernel( DefaultBlock2CTileMap>; // //
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg); ave_time = launch_and_time_kernel(
} stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
else (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
{ }
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, else
ADataType, {
CDataType, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
InMemoryDataOperationEnum::AtomicAdd, ADataType,
true, CDataType,
false, BLayout,
DefaultBlock2CTileMap>; // // InMemoryDataOperationEnum::AtomicAdd,
ave_time = launch_and_time_kernel( true,
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg); false,
} DefaultBlock2CTileMap>; // //
} ave_time = launch_and_time_kernel(
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
{ (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
if(karg.k_batch == 1) }
}
else if (!has_main_k_block_loop && has_double_tail_k_block_loop)
{
if (karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
BLayout,
InMemoryDataOperationEnum::Set,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
(karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
BLayout,
InMemoryDataOperationEnum::AtomicAdd,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
(karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
}
}
else
{
if (karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
BLayout,
InMemoryDataOperationEnum::Set,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
(karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
BLayout,
InMemoryDataOperationEnum::AtomicAdd,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
(karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
}
}
return ave_time;
}
// polymorphic
float
Run(const BaseArgument *p_arg,
const StreamConfig &stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument *>(p_arg), stream_config);
}
};
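The pattern repeated in every branch of Invoker::Run above is the core of this change: the kernel is launched with the members of karg passed individually (plus the block-to-C-tile map) rather than with the whole Argument struct. A condensed sketch of that dispatch, assuming a hypothetical helper name but the same Argument field names as in the source:

// Hedged sketch only; the real code instantiates kernel_tsmm_dl_v1r3 with the
// branch-specific (Set/AtomicAdd, HasMainKBlockLoop, HasDoubleTailKBlockLoop)
// template arguments before calling this.
template <typename Kernel, typename Arg, typename B2CMap>
float launch_flattened(const StreamConfig& stream_config,
                       Kernel kernel,
                       index_t grid_size,
                       index_t block_size,
                       const Arg& karg,
                       const B2CMap& b2c_map)
{
    // individual pointers/scalars (not the Arg aggregate) become the kernel
    // arguments, so the leading kernarg dwords are eligible for SGPR preload
    return launch_and_time_kernel(stream_config, kernel, dim3(grid_size), dim3(block_size), 0,
                                  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid,
                                  karg.M, karg.N, karg.K, karg.K0, karg.k_batch,
                                  karg.MPadded, karg.NPadded, b2c_map);
}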
static constexpr bool IsValidCompilationParameter()
{ {
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, // TODO: properly implement this check
ADataType, return true;
CDataType,
InMemoryDataOperationEnum::Set,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
} }
else // //
static bool IsSupportedArgument(const Argument &arg)
{ {
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, if (ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ADataType, ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
CDataType, ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
InMemoryDataOperationEnum::AtomicAdd, ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
false, {
true, return GridwiseTsmm::CheckValidity(arg);
DefaultBlock2CTileMap>; // // }
ave_time = launch_and_time_kernel( else
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg); {
return false;
}
} }
} // //
else // polymorphic
{ bool IsSupportedArgument(const BaseArgument *p_arg) override
if(karg.k_batch == 1)
{ {
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, return IsSupportedArgument(*dynamic_cast<const Argument *>(p_arg));
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
} }
else
static auto MakeArgument(const ADataType *p_a,
const BDataType *p_b,
CDataType *p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
index_t KBatch) // //
{ {
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, return Argument{p_a,
ADataType, p_b,
CDataType, p_c,
InMemoryDataOperationEnum::AtomicAdd, M,
false, N,
false, K,
DefaultBlock2CTileMap>; // // StrideA,
ave_time = launch_and_time_kernel( StrideB,
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg); StrideC,
GridwiseTsmm::CalculateMPadded(M),
GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch}; // //
} }
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter() static auto MakeInvoker() { return Invoker{}; }
{
// TODO: properly implement this check
return true;
}
// //
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
return GridwiseTsmm::CheckValidity(arg);
}
else
{
return false;
}
}
// //
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
index_t KBatch) // //
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
// GridwiseTsmm::CalculateMPadded(M),
// GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch}; // //
}
static auto MakeInvoker() { return Invoker{}; } // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void *p_a,
// polymorphic const void *p_b,
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, void *p_c,
const void* p_b, index_t M,
void* p_c, index_t N,
index_t M, index_t K,
index_t N, index_t StrideA,
index_t K, index_t StrideB,
index_t StrideA, index_t StrideC,
index_t StrideB, AElementwiseOperation,
index_t StrideC, BElementwiseOperation,
AElementwiseOperation, CElementwiseOperation,
BElementwiseOperation, ck::index_t KBatch = 1) override // //
CElementwiseOperation, {
ck::index_t KBatch = 1) override // //
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType *>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType *>(p_b),
static_cast<CDataType*>(p_c), static_cast<CDataType *>(p_c),
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
// GridwiseTsmm::CalculateMPadded(M), GridwiseTsmm::CalculateMPadded(M),
// GridwiseTsmm::CalculateNPadded(N), GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch), // GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch), GridwiseTsmm::CalculateK0(K, KBatch),
KBatch); // // KBatch); // //
} }
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
} }
// polymorphic // polymorphic
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "deviceTsmmDl" str << "deviceTsmmDl"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
...@@ -366,12 +385,12 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout, ...@@ -366,12 +385,12 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
<< NPerThread << ", " << NPerThread << ", "
<< KPerThread << KPerThread
<< ">"; << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
} }
}; };
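For context, a usage sketch of this device op's host-side API as declared above. This is hedged: the tuning-parameter list is elided, and the data types, sizes, and KBatch value are illustrative only.

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// hypothetical instantiation; the full tuning-parameter list is elided here
using DeviceOp = ck::tensor_operation::device::deviceTsmmDl</* ...template arguments... */>;

float run_tsmm_fp16(const ck::half_t* p_a,
                    const ck::half_t* p_b,
                    ck::half_t* p_c,
                    ck::index_t M, ck::index_t N, ck::index_t K,
                    ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    auto op      = DeviceOp{};
    auto arg     = op.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                   PassThrough{}, PassThrough{}, PassThrough{}, /*KBatch=*/4);
    auto invoker = op.MakeInvoker();

    if(!op.IsSupportedArgument(arg))
        return -1.f; // unsupported device or problem shape

    return invoker.Run(arg, StreamConfig{nullptr, true}); // average kernel time in ms
}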
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -16,757 +16,769 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck
{
template <typename GridwiseTsmm,
typename FloatAB, template <typename GridwiseTsmm,
typename FloatC, typename FloatAB,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, typename FloatC,
bool HasMainKBlockLoop, typename BLayout,
bool HasDoubleTailKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename Block2CTileMap> bool HasMainKBlockLoop,
__global__ void bool HasDoubleTailKBlockLoop,
typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_tsmm_dl_v1r3( kernel_tsmm_dl_v1r3(
typename GridwiseTsmm::Argument karg) //: in __global__ functions, struct is const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid, index_t M, index_t N, index_t K,
// better for reduced load overhead index_t K0, index_t k_batch, index_t MPadded, index_t NPadded, const Block2CTileMap block_2_ctile_map) //: in __global__ functions, struct is
{ // better for reduced load overhead
GridwiseTsmm::template Run<HasMainKBlockLoop,
HasDoubleTailKBlockLoop,
GridwiseTsmm,
CGlobalMemoryDataOperation>(karg);
}
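The kernel signature change above is the device-side half of the optimization: kernel_tsmm_dl_v1r3 now receives individual pointers and index_t scalars (plus the block-to-C-tile map) instead of a single GridwiseTsmm::Argument struct. A toy before/after sketch, using hypothetical kernels that are not part of this commit:

// before: a single aggregate kernel argument
struct ToyArgs
{
    const float* a;
    float* c;
    int n;
};

__global__ void toy_kernel_struct(ToyArgs karg)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < karg.n)
        karg.c[i] = karg.a[i];
}

// after: individual pointers/scalars, matching the flattened kernarg layout this
// commit adopts so the leading arguments can be preloaded (per the
// --amdgpu-kernarg-preload-count=16 option added in CMake)
__global__ void toy_kernel_scalars(const float* a, float* c, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        c[i] = a[i];
}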
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
typename ALayout,
typename BLayout,
typename CLayout,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseTsmmDl_km_kn_mn
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
// Argument
struct Argument : public tensor_operation::device::BaseArgument //
{ {
Argument(const FloatAB* p_a_grid_, // strides depend on B's layout
const FloatAB* p_b_grid_, if constexpr (is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
// index_t MPadded_,
// index_t NPadded_,
// index_t KPadded_,
index_t K0_,
index_t k_batch_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
// MPadded(MPadded_),
// NPadded(NPadded_),
// KPadded(KPadded_),
K0(K0_),
k_batch(k_batch_)
{ {
GridwiseTsmm::template Run<HasMainKBlockLoop,
HasDoubleTailKBlockLoop,
GridwiseTsmm,
CGlobalMemoryDataOperation>(p_a_grid, p_b_grid, p_c_grid, M, N, K,
K0, k_batch, K, N, N, MPadded, NPadded, block_2_ctile_map);
}
else
{
GridwiseTsmm::template Run<HasMainKBlockLoop,
HasDoubleTailKBlockLoop,
GridwiseTsmm,
CGlobalMemoryDataOperation>(p_a_grid, p_b_grid, p_c_grid, M, N, K,
K0, k_batch, K, K, N, MPadded, NPadded, block_2_ctile_map);
} }
// private:
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M, N, K;
index_t StrideA, StrideB, StrideC;
//:
// index_t MPadded;
// index_t NPadded;
// index_t KPadded;
index_t K0;
index_t k_batch;
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
} }
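A worked example for GetSharedMemoryNumberOfByte above, with hedged, purely illustrative tile sizes:

// With K0PerBlock = 4, MPerBlock = 64, K1 = 4 and FloatAB = half (2 bytes),
// the A tile holds 4 * 64 * 4 = 1024 elements (already K1-aligned), and the
// double-buffered LDS footprint is 2 * 1024 * 2 B = 4096 B.
static_assert(2 * (4 * 64 * 4) * 2 == 4096, "LDS bytes for the example tile above");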
__host__ __device__ static constexpr index_t template <index_t BlockSize,
CalculateGridSize(index_t M, index_t N, index_t k_batch) // typename FloatAB,
typename FloatAcc,
typename FloatC,
typename ALayout,
typename BLayout,
typename CLayout,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseTsmmDl_km_kn_mn
{ {
const index_t grid_size = math::integer_divide_ceil(N, NPerBlock) * static constexpr auto I0 = Number<0>{};
math::integer_divide_ceil(M, MPerBlock) * k_batch; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
return grid_size; // K1 should be Number<...>
} static constexpr auto K1 = Number<K1Value>{};
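A worked example for CalculateGridSize above, with hedged, illustrative tall-and-skinny sizes (divisions are exact here, so integer division matches integer_divide_ceil):

// M = 8192, N = 16, MPerBlock = 64, NPerBlock = 16, k_batch = 4 gives
// ceil(16 / 16) * ceil(8192 / 64) * 4 = 1 * 128 * 4 = 512 workgroups.
static_assert((16 / 16) * (8192 / 64) * 4 == 512, "grid size for the example above");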
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) // Argument
{ struct Argument : public tensor_operation::device::BaseArgument //
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; {
Argument(const FloatAB *p_a_grid_,
const FloatAB *p_b_grid_,
FloatC *p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
// index_t KPadded_,
index_t K0_,
index_t k_batch_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded(MPadded_),
NPadded(NPadded_),
// KPadded(KPadded_),
K0(K0_),
k_batch(k_batch_)
{
}
return has_main_k_block_loop; // private:
} const FloatAB *p_a_grid;
const FloatAB *p_b_grid;
FloatC *p_c_grid;
index_t M, N, K;
index_t StrideA, StrideB, StrideC;
//:
index_t MPadded;
index_t NPadded;
// index_t KPadded;
index_t K0;
index_t k_batch;
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
{ // TODO: check alignment
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
return has_double_tail_k_block_loop; return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
} }
__host__ __device__ static auto CalculateMPadded(index_t M) __host__ __device__ static constexpr index_t
{ CalculateGridSize(index_t M, index_t N, index_t k_batch) //
return math::integer_least_multiple(M, MPerBlock); {
} const index_t grid_size = math::integer_divide_ceil(N, NPerBlock) *
math::integer_divide_ceil(M, MPerBlock) * k_batch;
__host__ __device__ static auto CalculateNPadded(index_t N) return grid_size;
{ }
return math::integer_least_multiple(N, NPerBlock);
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{ {
// k_batch * k0 * k0_per_block * k1 const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
auto K_t = K_Batch * K0PerBlock * K1;
return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) return has_main_k_block_loop;
{ }
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
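A worked example for CalculateK0 / CalculateKPadded and the two loop predicates above, with hedged, illustrative sizes:

// K = 1024, K1 = 8, K0PerBlock = 16, K_Batch = 2:
//   K_t     = 2 * 16 * 8 = 256
//   K0      = ceil(1024 / 256) * 16 = 64
//   KPadded = 2 * 64 * 8 = 1024 (already a multiple of K_t, so no extra padding)
//   has_main_k_block_loop        = (64 + 16) / (2 * 16) > 1  -> true
//   has_double_tail_k_block_loop = (64 / 16) % 2 == 0        -> true
static_assert((1024 + 256 - 1) / 256 * 16 == 64, "K0 for the example above");
static_assert(2 * 64 * 8 == 1024, "KPadded for the example above");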
static constexpr auto K1Number = Number<K1>{}; __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
// M, K -> KBatch, K0, M, K1: M -> MPad, K->KBatch, K0, K1 {
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1( const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
{
const auto a_grid_desc_m_k = [&]() { return has_double_tail_k_block_loop;
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) }
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) __host__ __device__ static auto CalculateMPadded(index_t M)
{ {
return math::integer_least_multiple(M, MPerBlock);
}
return transform_tensor_descriptor( __host__ __device__ static auto CalculateNPadded(index_t N)
a_grid_desc_m_k, {
make_tuple(make_unmerge_transform( return math::integer_least_multiple(N, NPerBlock);
make_tuple(KBatch, K0, K1Number)), // unmerge is split 1D to 3D
make_right_pad_transform(M, MPad - M)), //
make_tuple(Sequence<1>{}, Sequence<0>{}), // mapped to input M & K; sequence 0 is M;
// 1 is K; make unmerge is working on K;
make_tuple(Sequence<0, 1, 3>{}, // input is M,K; output we want is Kbatch, K0 and K1
// -> 0, 1, 3; output is transformed from 2D to 4D
Sequence<2>{})); // 2->M
} }
else
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
{ {
return transform_tensor_descriptor( // k_batch * k0 * k0_per_block * k1
a_grid_desc_m_k, auto K_t = K_Batch * K0PerBlock * K1;
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), return (K + K_t - 1) / K_t * K0PerBlock;
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
} }
}
__host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1( __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0) {
{ auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
const auto b_grid_desc_k_n = [&]() { static constexpr auto K1Number = Number<K1>{};
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
// M, K -> KBatch, K0, M, K1: M -> MPad, K->KBatch, K0, K1
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(
index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
{
const auto a_grid_desc_m_k = [&]()
{ {
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); if constexpr (is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(
make_tuple(KBatch, K0, K1Number)), // unmerge is split 1D to 3D
make_right_pad_transform(M, MPad - M)), //
make_tuple(Sequence<1>{}, Sequence<0>{}), // mapped to input M & K; sequence 0 is M;
// 1 is K; make unmerge is working on K;
make_tuple(Sequence<0, 1, 3>{}, // input is M,K; output we want is Kbatch, K0 and K1
// -> 0, 1, 3; output is transformed from 2D to 4D
Sequence<2>{})); // 2->M
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else
{ {
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
} }
}(); }
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(
index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
{ {
return transform_tensor_descriptor( const auto b_grid_desc_k_n = [&]()
b_grid_desc_k_n, {
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), if constexpr (is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
make_right_pad_transform(N, NPad - N)), {
make_tuple(Sequence<0>{}, Sequence<1>{}), return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); }
} else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
else {
{ return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
return transform_tensor_descriptor( }
b_grid_desc_k_n, }();
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)), if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
make_tuple(Sequence<0>{}, Sequence<1>{}), {
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
} }
}
__host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{ {
const auto c_grid_desc_m_n = [&]() { const auto c_grid_desc_m_n = [&]()
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) {
if constexpr (is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{ {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value) else
{ {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
}(); }
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
using AGridDesc_Kbatch_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1, 1));
using BGridDesc_Kbatch_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) __host__ __device__ static constexpr bool CheckValidity(const Argument &karg)
{ {
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor( // const auto MPadded = CalculateMPadded(karg.M);
c_grid_desc_m_n, // const auto NPadded = CalculateNPadded(karg.N);
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
make_tuple(Sequence<0>{}, Sequence<1>{}), karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
make_tuple(Sequence<0>{}, Sequence<1>{})); const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto KBatch_b = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto N_ = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
return (M_ % MPerBlock == 0 && N_ % NPerBlock == 0 && K0_ % K0PerBlock == 0 &&
M_ == c_grid_desc_m_n.GetLength(I0) && N_ == c_grid_desc_m_n.GetLength(I1) &&
a_grid_desc_kbatch_k0_m_k1.GetLength(I3) ==
b_grid_desc_kbatch_k0_n_k1.GetLength(I3) &&
karg.k_batch >= 1 && KBatch_a == karg.k_batch && KBatch_b == karg.k_batch);
} }
else
// KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
__host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
const AGridDesc_Kbatch_K0_M_K1 &a_grid_desc_kbatch_k0_m_k1)
{
const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto K0 = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_kbatch_k0_m0_m1_k1 = transform_tensor_descriptor(
a_grid_desc_kbatch_k0_m_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), // IP
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); // OP
return a_grid_desc_kbatch_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
const BGridDesc_Kbatch_K0_N_K1 &b_grid_desc_kbatch_k0_n_k1)
{
const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0 = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
const auto N = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_kbatch_k0_n0_n1_k1 = transform_tensor_descriptor(
b_grid_desc_kbatch_k0_n_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_kbatch_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
return transform_tensor_descriptor( constexpr auto M11 = Number<MPerThread>{};
constexpr auto N11 = Number<NPerThread>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
} }
}
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch) // return block_id to C matrix tile idx (m0, n0) mapping
{ __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; {
const index_t KPad = KBatch * K0 * K1; //: 3d ksplit for C
return KPad; return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
} }
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>; //
using AGridDesc_K0_M0_M1_K1 =
decltype(MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(AGridDesc_Kbatch_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(BGridDesc_Kbatch_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); //
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap()); //
template <bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop,
typename GridwiseTsmm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static void Run(const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid, index_t M, index_t N, index_t K,
index_t K0, index_t k_batch, index_t StrideA, index_t StrideB, index_t StrideC, index_t MPadded, index_t NPadded, const Block2CTileMap &block_2_ctile_map)
{
using AGridDesc_Kbatch_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1, 1)); constexpr index_t shared_block_size =
using BGridDesc_Kbatch_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1)); GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
__shared__ FloatAB p_shared_block[shared_block_size];
const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
M, MPadded, K, StrideA, k_batch, K0); //
const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
K, NPadded, N, StrideB, k_batch, K0); //
const auto c_grid_desc_m_n =
GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
const auto b_grid_desc_kbatch_k0_n0_n1_k1 =
GridwiseTsmm::MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1); //
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseTsmm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_kbatch_k0_n0_n1_k1.GetElementSpaceSize());
ignore = b_global_buf;
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
get_block_1d_id(), N, k_batch);
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);
if (!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg) // TODO: change this. I think it needs multi-dimensional alignment
{ constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_copy_kbatch_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>, //: 5 dimensions; kbatch for each
// dimension is 1
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, // 0, 1, 2, 3, 4
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_kbatch_k0_m0_m1_k1)>, // Global tensor desc
decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1), // block tensor desc
ABlockTransferSrcAccessOrder, // 5-dim
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_kbatch_k0_m0_m1_k1, // for src desc
make_multi_index(kbatch_id, 0, im0, 0, 0), //: calculate start index of K
a_block_desc_copy_kbatch_k0_m0_m1_k1, // for dst desc
make_multi_index(0, 0, 0, 0, 0));
static constexpr auto b_thread_desc_copy_kbatch_k0_n0_n1_k1 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<K0PerBlock>{},
I1,
Number<NPerThread>{},
Number<K1>{})); //: this descriptor is used only for copy
static constexpr auto b_thread_desc_copy_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_kbatch_k0_n0_n1_k1)>,
decltype(b_thread_desc_copy_kbatch_k0_n0_n1_k1), //
Sequence<1, K0PerBlock, 1, NPerThread, K1.value>,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
1,
false,
true>(b_grid_desc_kbatch_k0_n0_n1_k1,
make_multi_index(kbatch_id, 0, in0, get_thread_local_1d_id() * NPerThread, 0));
const auto MPadded = CalculateMPadded(karg.M); static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
const auto NPadded = CalculateNPadded(karg.N); make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
const auto KBatch_b = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
const auto K0_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M_ = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto N_ = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
return (M_ % MPerBlock == 0 && N_ % NPerBlock == 0 && K0_ % K0PerBlock == 0 &&
M_ == c_grid_desc_m_n.GetLength(I0) && N_ == c_grid_desc_m_n.GetLength(I1) &&
a_grid_desc_kbatch_k0_m_k1.GetLength(I3) ==
b_grid_desc_kbatch_k0_n_k1.GetLength(I3) &&
karg.k_batch >= 1 && KBatch_a == karg.k_batch && KBatch_b == karg.k_batch);
}
// KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1 // TODO: check alignment
__host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1( // A matrix in LDS memory, dst of blockwise copy
const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1) // be careful of LDS alignment
{ constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0); make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
const auto K0 = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto M = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_kbatch_k0_m0_m1_k1 = transform_tensor_descriptor(
a_grid_desc_kbatch_k0_m_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), // IP
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); // OP
return a_grid_desc_kbatch_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1( // TODO: check alignment
const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1) // A matrix in LDS memory, for blockwise GEMM
{ constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
const auto K0 = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
const auto N = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_kbatch_k0_n0_n1_k1 = transform_tensor_descriptor(
b_grid_desc_kbatch_k0_n_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_kbatch_k0_n0_n1_k1;
}
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M11 = Number<MPerThread>{};
        constexpr auto N11 = Number<NPerThread>{};

        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                       make_unmerge_transform(make_tuple(N0, N10, N11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_grid_desc_m0_m10_m11_n0_n10_n11;
    }
    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
    {
        //: 3d ksplit for C
        return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
    }

    using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;

    using AGridDesc_K0_M0_M1_K1 =
        decltype(MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(AGridDesc_Kbatch_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 =
        decltype(MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(BGridDesc_Kbatch_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
    using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap());
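
    // Per-workgroup TSMM pipeline: block_2_ctile_map assigns each workgroup one
    // (m0, n0, kbatch) C tile; A tiles are double-buffered in LDS, B tiles are
    // double-buffered in per-thread VGPR buffers, and the accumulated C tile is
    // finally written from registers to global memory.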
    template <bool HasMainKBlockLoop,
              bool HasDoubleTailKBlockLoop,
              typename GridwiseTsmm,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation>
    __device__ static void Run(const Argument& karg)
    {
        constexpr index_t shared_block_size =
            GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

        __shared__ FloatAB p_shared_block[shared_block_size];

        const Block2CTileMap& block_2_ctile_map = Block2CTileMap{};

        const auto MPadded = CalculateMPadded(karg.M);
        const auto NPadded = CalculateNPadded(karg.N);

        const FloatAB* p_a_grid = karg.p_a_grid;
        const FloatAB* p_b_grid = karg.p_b_grid;
        FloatC* p_c_grid        = karg.p_c_grid;

        const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
        const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
        const auto c_grid_desc_m_n =
            GridwiseTsmm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

        const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
            GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1);
        const auto b_grid_desc_kbatch_k0_n0_n1_k1 =
            GridwiseTsmm::MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1);
        const auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
            GridwiseTsmm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);

        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_kbatch_k0_n0_n1_k1.GetElementSpaceSize());
        ignore = b_global_buf;
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
            get_block_1d_id(), karg.N, karg.k_batch);

        // HACK: this forces the index data into SGPRs
        const index_t im0       = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0       = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
        const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);

        if(!block_2_ctile_map.ValidCTileIndex(
               make_tuple(im0, in0),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        constexpr auto a_block_desc_copy_kbatch_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum::Set,
            Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>, //: 5 dimensions; kbatch for each
                                                             // dimension is 1
            ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
            ABlockTransferThreadClusterArrangeOrder, // 0, 1, 2, 3, 4
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(a_grid_desc_kbatch_k0_m0_m1_k1)>, // global tensor desc
            decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1),               // block tensor desc
            ABlockTransferSrcAccessOrder, // 5-dim
            Sequence<0, 1, 2, 3, 4>,
            ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
            ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
            ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>,                         // DstVectorTensorContiguousDimOrder
            false,
            true>(a_grid_desc_kbatch_k0_m0_m1_k1,            // for src desc
                  make_multi_index(kbatch_id, 0, im0, 0, 0), //: calculate start index of K
                  a_block_desc_copy_kbatch_k0_m0_m1_k1,      // for dst desc
                  make_multi_index(0, 0, 0, 0, 0));

        static constexpr auto b_thread_desc_copy_kbatch_k0_n0_n1_k1 =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<K0PerBlock>{},
                           I1,
                           Number<NPerThread>{},
                           Number<K1>{})); //: this descriptor is used only for copy

        static constexpr auto b_thread_desc_copy_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));

        auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(b_grid_desc_kbatch_k0_n0_n1_k1)>,
            decltype(b_thread_desc_copy_kbatch_k0_n0_n1_k1),
            Sequence<1, K0PerBlock, 1, NPerThread, K1.value>,
            BThreadTransferSrcDstAccessOrder,
            BThreadTransferSrcVectorDim,
            BThreadTransferSrcScalarPerVector,
            1,
            false,
            true>(b_grid_desc_kbatch_k0_n0_n1_k1,
                  make_multi_index(kbatch_id, 0, in0, get_thread_local_1d_id() * NPerThread, 0));

        static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                          a_k0_m_k1_block_desc.GetElementSpaceSize(),
                      "wrong!");

        const auto blockwise_tsmm =
            BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                 FloatAB,
                                                 FloatAB,
                                                 FloatAcc,
                                                 decltype(a_k0_m_k1_block_desc),
                                                 decltype(b_k0_n_k1_thread_desc),
                                                 MPerThread,
                                                 NPerBlock,
                                                 KPerThread>{};

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_tsmm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

        constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;

        auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_k0_n_k1_thread_desc.GetElementSpaceSize());
        auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_k0_n_k1_thread_desc.GetElementSpaceSize());

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

        // Initialize C
        c_thread_buf.Clear();
        constexpr auto a_block_slice_copy_step  = make_multi_index(0, K0PerBlock, 0, 0, 0);
        constexpr auto b_thread_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1,
                                     a_global_buf); // a_global_buf -> reg_tmp_buf
            a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1,
                                      a_block_even_buf); // reg_tmp_buf -> a_block_even_buf
            b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
                                  b_global_buf,
                                  b_thread_desc_copy_k0_n0_n1_k1,
                                  make_tuple(I0, I0, I0, I0, I0),
                                  b_thread_even_buf);
        }
        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);

            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
                                                     b_thread_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
                b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
                                      b_global_buf,
                                      b_thread_desc_copy_k0_n0_n1_k1,
                                      make_tuple(I0, I0, I0, I0, I0),
                                      b_thread_odd_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
                                                     b_thread_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
                b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
                                      b_global_buf,
                                      b_thread_desc_copy_k0_n0_n1_k1,
                                      make_tuple(I0, I0, I0, I0, I0),
                                      b_thread_even_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }
        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                a_block_slice_copy_step);
            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
                                                 b_thread_slice_copy_step);

            block_sync_lds();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
            b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
                                  b_global_buf,
                                  b_thread_desc_copy_k0_n0_n1_k1,
                                  make_tuple(I0, I0, I0, I0, I0),
                                  b_thread_odd_buf);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);

            block_sync_lds();

            // LDS double buffer: GEMM on last data
            blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
        }
        else // if 1 iteration is left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
        }
        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                               I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                blockwise_tsmm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                    get_thread_local_1d_id());

            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
                decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                         1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                      make_multi_index(im0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                       in0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                      ck::tensor_operation::element_wise::PassThrough{}}
                .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_grid_desc_m0_m10_m11_n0_n10_n11,
                     c_grid_buf);
        }
    }
};

} // namespace ck