Commit f88c2f86 authored by Harisankar Sadasivan

Kernarg load latency optimization for MI300

parent c2784145
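Context for the change: on gfx940-class GPUs (MI300), the LLVM backend can preload the leading kernel arguments into SGPRs at wavefront launch instead of fetching them from the kernarg segment with scalar loads, shortening the kernel prologue. The CMake hunks below opt two examples into this via `-mllvm --amdgpu-kernarg-preload-count=16` and gate the flattened-argument code path behind `-DKERNARG_PRELOAD`; the device-op changes further down add gfx940/941/942 to the supported targets. A minimal standalone sketch of the argument shape this targets follows (hypothetical, not part of the commit; the `axpy` kernel, build line, and sizes are assumptions):

```cpp
// Hypothetical standalone example: leading scalar/pointer kernel arguments
// are the shape the kernarg-preload pass can place directly into SGPRs.
// Assumed build line, mirroring the flags added below:
//   hipcc --offload-arch=gfx942 -mllvm --amdgpu-kernarg-preload-count=16 axpy.cpp
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void axpy(const float* x, float* y, float a, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        y[i] = a * x[i] + y[i]; // x, y, a, n can arrive preloaded in SGPRs
}

int main()
{
    const int n = 1 << 20;
    float* x   = nullptr;
    float* y   = nullptr;
    if(hipMalloc(&x, n * sizeof(float)) != hipSuccess ||
       hipMalloc(&y, n * sizeof(float)) != hipSuccess)
        return 1;

    axpy<<<(n + 255) / 256, 256>>>(x, y, 2.0f, n);
    if(hipDeviceSynchronize() != hipSuccess)
        return 1;

    printf("axpy done\n");
    hipFree(x);
    hipFree(y);
    return 0;
}
```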
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
         add_dependencies(example_gemv_splitk
                          example_gemv_splitk_fp16)
+        set_source_files_properties(gemv_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
         set(target 1)
     endif()
 endforeach()
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
         add_dependencies(example_tall_and_skinny_gemm_splitk
                          example_tall_and_skinny_gemm_splitk_fp16)
+        set_source_files_properties(tall_and_skinny_gemm_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
         set(target 1)
     endif()
 endforeach()
\ No newline at end of file
@@ -9,8 +9,9 @@
 #include "ck/stream_config.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
 
+#ifndef KERNARG_PRELOAD
 template <typename... Args, typename F>
-float launch_and_time_kernel(const StreamConfig& stream_config,
+float launch_and_time_kernel(const StreamConfig &stream_config,
                              F kernel,
                              dim3 grid_dim,
                              dim3 block_dim,
@@ -18,7 +19,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
                              Args... args)
 {
 #if CK_TIME_KERNEL
-    if(stream_config.time_kernel_)
+    if (stream_config.time_kernel_)
     {
 #if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -48,7 +49,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
         hip_check_error(hipDeviceSynchronize());
         hip_check_error(hipEventRecord(start, stream_config.stream_id_));
 
-        for(int i = 0; i < nrepeat; ++i)
+        for (int i = 0; i < nrepeat; ++i)
         {
             kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
             hip_check_error(hipGetLastError());
@@ -78,8 +79,83 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
 #endif
 }
+#else
+template <typename... Args, typename F>
+float launch_and_time_kernel(const StreamConfig &stream_config,
+                             F kernel,
+                             dim3 grid_dim,
+                             dim3 block_dim,
+                             std::size_t lds_byte,
+                             Args... args)
+{
+    // Args* args1;
+    // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
+    // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
+#if CK_TIME_KERNEL
+    if (stream_config.time_kernel_)
+    {
+#if DEBUG_LOG
+        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
+               __func__,
+               grid_dim.x,
+               grid_dim.y,
+               grid_dim.z,
+               block_dim.x,
+               block_dim.y,
+               block_dim.z);
+
+        printf("Warm up 1 time\n");
+#endif
+        // warm up
+        const int nrepeat = 1000;
+        for (auto i = 0; i < nrepeat; i++)
+            hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_,
+                               args...);
+        hip_check_error(hipGetLastError());
+
+#if DEBUG_LOG
+        printf("Start running %d times...\n", nrepeat);
+#endif
+
+        hipEvent_t start, stop;
+        float total_time = 0;
+
+        hip_check_error(hipEventCreate(&start));
+        hip_check_error(hipEventCreate(&stop));
+        hip_check_error(hipDeviceSynchronize());
+        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+
+        for (int i = 0; i < nrepeat; ++i)
+            hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_,
+                               args...);
+        // hip_check_error(hipGetLastError());
+
+        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+        hip_check_error(hipEventSynchronize(stop));
+        hip_check_error(hipEventElapsedTime(&total_time, start, stop));
+
+        return total_time / nrepeat;
+    }
+    else
+    {
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        hip_check_error(hipGetLastError());
+
+        return 0;
+    }
+#else
+    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+    hip_check_error(hipGetLastError());
+
+    return 0;
+#endif
+}
+#endif
 template <typename... Args, typename F, typename PreProcessFunc>
-float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
+float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
                                              PreProcessFunc preprocess,
                                              F kernel,
                                              dim3 grid_dim,
@@ -88,7 +164,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              Args... args)
 {
 #if CK_TIME_KERNEL
-    if(stream_config.time_kernel_)
+    if (stream_config.time_kernel_)
     {
 #if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -119,7 +195,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
         hip_check_error(hipDeviceSynchronize());
         hip_check_error(hipEventRecord(start, stream_config.stream_id_));
 
-        for(int i = 0; i < nrepeat; ++i)
+        for (int i = 0; i < nrepeat; ++i)
         {
             preprocess();
             kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
...
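The new `KERNARG_PRELOAD` launcher above differs from the original in two ways: it forwards each kernel argument individually through `hipLaunchKernelGGL`, so the compiler sees discrete scalars and pointers rather than one struct, and its warm-up loop runs the same `nrepeat = 1000` launches as the timed loop. A self-contained sketch of the hipEvent timing pattern it uses (the no-op kernel and the default stream are stand-ins, not from the commit):

```cpp
// Minimal sketch of the event-based average-latency measurement above.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void noop_kernel() {}

int main()
{
    const int nrepeat = 1000;
    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);

    // warm up with the same number of launches as the timed loop
    for(int i = 0; i < nrepeat; ++i)
        hipLaunchKernelGGL(noop_kernel, dim3(1), dim3(64), 0, 0);
    (void)hipDeviceSynchronize();

    // time nrepeat back-to-back launches between two events
    (void)hipEventRecord(start, 0);
    for(int i = 0; i < nrepeat; ++i)
        hipLaunchKernelGGL(noop_kernel, dim3(1), dim3(64), 0, 0);
    (void)hipEventRecord(stop, 0);
    (void)hipEventSynchronize(stop);

    float total_ms = 0;
    (void)hipEventElapsedTime(&total_ms, start, stop);
    printf("average per launch: %f ms\n", total_ms / nrepeat);

    (void)hipEventDestroy(start);
    (void)hipEventDestroy(stop);
    return 0;
}
```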
@@ -16,11 +16,14 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 
-namespace ck {
-namespace tensor_operation {
-namespace device {
+namespace ck
+{
+namespace tensor_operation
+{
+namespace device
+{
 
 template <
     typename ADataType,
     typename BDataType,
     typename CDataType,
@@ -58,7 +61,7 @@ template <
     is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
     is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
     bool> = false>
 struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                         BLayout,
                                         CLayout,
                                         ADataType,
@@ -68,7 +71,7 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                         BElementwiseOperation,
                                         CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -113,11 +116,11 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
     struct Invoker : public BaseInvoker
     {
-        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument &karg, const StreamConfig &stream_config = StreamConfig{})
         {
             const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
 
-            // const auto b2c_map = DefaultBlock2CTileMap{};
+            const auto b2c_map = DefaultBlock2CTileMap{};
 
             const auto K0 = karg.K0;
@@ -127,128 +130,144 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
             float ave_time = 0;
 
-            if(karg.k_batch > 1)
+            if (karg.k_batch > 1)
                 hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
 
-            if(has_main_k_block_loop && has_double_tail_k_block_loop)
+            if (has_main_k_block_loop && has_double_tail_k_block_loop)
             {
-                if(karg.k_batch == 1)
+                if (karg.k_batch == 1)
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::Set,
                                                             true,
                                                             true,
                                                             DefaultBlock2CTileMap>; // //
                     ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid,
+                        karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
+                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                 }
                 else
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             true,
                                                             true,
                                                             DefaultBlock2CTileMap>; // //
                     ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid,
+                        karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
+                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                 }
             }
-            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
+            else if (has_main_k_block_loop && !has_double_tail_k_block_loop)
             {
-                if(karg.k_batch == 1)
+                if (karg.k_batch == 1)
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::Set,
                                                             true,
                                                             false,
                                                             DefaultBlock2CTileMap>; // //
                     ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid,
+                        karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
+                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                 }
                 else
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             true,
                                                             false,
                                                             DefaultBlock2CTileMap>; // //
                     ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid,
+                        karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
+                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                 }
             }
-            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
+            else if (!has_main_k_block_loop && has_double_tail_k_block_loop)
             {
-                if(karg.k_batch == 1)
+                if (karg.k_batch == 1)
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::Set,
                                                             false,
                                                             true,
                                                             DefaultBlock2CTileMap>; // //
                     ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid,
+                        karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
+                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                 }
                 else
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             false,
                                                             true,
                                                             DefaultBlock2CTileMap>; // //
                     ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid,
+                        karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
+                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                 }
             }
             else
             {
-                if(karg.k_batch == 1)
+                if (karg.k_batch == 1)
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::Set,
                                                             false,
                                                             false,
                                                             DefaultBlock2CTileMap>; // //
                     ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid,
+                        karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
+                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                 }
                 else
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             false,
                                                             false,
                                                             DefaultBlock2CTileMap>; // //
                     ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg.p_a_grid,
+                        karg.p_b_grid, karg.p_c_grid, (karg.M), (karg.N), (karg.K),
+                        (karg.K0), (karg.k_batch), karg.MPadded, karg.NPadded, b2c_map);
                 }
             }
 
             return ave_time;
         }
 
         // polymorphic
-        float Run(const BaseArgument* p_arg,
-                  const StreamConfig& stream_config = StreamConfig{}) override
+        float
+        Run(const BaseArgument *p_arg,
+            const StreamConfig &stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+            return Run(*dynamic_cast<const Argument *>(p_arg), stream_config);
         }
     };
@@ -258,12 +277,12 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         return true;
     }
 
     // //
-    static bool IsSupportedArgument(const Argument& arg)
+    static bool IsSupportedArgument(const Argument &arg)
     {
-        if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
+        if (ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
            ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
            ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+           ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx940" ||
+           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
         {
             return GridwiseTsmm::CheckValidity(arg);
         }
@@ -274,14 +293,14 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
     }
 
     // //
     // polymorphic
-    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    bool IsSupportedArgument(const BaseArgument *p_arg) override
     {
-        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+        return IsSupportedArgument(*dynamic_cast<const Argument *>(p_arg));
     }
 
-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             CDataType* p_c,
+    static auto MakeArgument(const ADataType *p_a,
+                             const BDataType *p_b,
+                             CDataType *p_c,
                              index_t M,
                              index_t N,
                              index_t K,
@@ -302,8 +321,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                              StrideA,
                              StrideB,
                              StrideC,
-                             // GridwiseTsmm::CalculateMPadded(M),
-                             // GridwiseTsmm::CalculateNPadded(N),
+                             GridwiseTsmm::CalculateMPadded(M),
+                             GridwiseTsmm::CalculateNPadded(N),
                              // GridwiseTsmm::CalculateKPadded(K, KBatch),
                              GridwiseTsmm::CalculateK0(K, KBatch),
                              KBatch}; // //
@@ -312,9 +331,9 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
     static auto MakeInvoker() { return Invoker{}; }
 
     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      void* p_c,
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void *p_a,
+                                                      const void *p_b,
+                                                      void *p_c,
                                                       index_t M,
                                                       index_t N,
                                                       index_t K,
@@ -327,17 +346,17 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                                       ck::index_t KBatch = 1) override // //
     {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
-                                          static_cast<CDataType*>(p_c),
+        return std::make_unique<Argument>(static_cast<const ADataType *>(p_a),
+                                          static_cast<const BDataType *>(p_b),
+                                          static_cast<CDataType *>(p_c),
                                           M,
                                           N,
                                           K,
                                           StrideA,
                                           StrideB,
                                           StrideC,
-                                          // GridwiseTsmm::CalculateMPadded(M),
-                                          // GridwiseTsmm::CalculateNPadded(N),
+                                          GridwiseTsmm::CalculateMPadded(M),
+                                          GridwiseTsmm::CalculateNPadded(N),
                                           // GridwiseTsmm::CalculateKPadded(K, KBatch),
                                           GridwiseTsmm::CalculateK0(K, KBatch),
                                           KBatch); // //
@@ -370,8 +389,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         return str.str();
     }
 };
 
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -16,31 +16,46 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
-namespace ck {
+namespace ck
+{
 
 template <typename GridwiseTsmm,
           typename FloatAB,
           typename FloatC,
+          typename BLayout,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop,
           typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     kernel_tsmm_dl_v1r3(
-        typename GridwiseTsmm::Argument karg) //: in __global__ functions, struct is
-                                              // better for reduced load overhead
+        const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid,
+        index_t M, index_t N, index_t K, index_t K0, index_t k_batch,
+        index_t MPadded, index_t NPadded,
+        const Block2CTileMap block_2_ctile_map) //: in __global__ functions, struct is
+                                                // better for reduced load overhead
 {
-    GridwiseTsmm::template Run<HasMainKBlockLoop,
-                               HasDoubleTailKBlockLoop,
-                               GridwiseTsmm,
-                               CGlobalMemoryDataOperation>(karg);
+    // strides depend on B's layout
+    if constexpr (is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+    {
+        GridwiseTsmm::template Run<HasMainKBlockLoop,
+                                   HasDoubleTailKBlockLoop,
+                                   GridwiseTsmm,
+                                   CGlobalMemoryDataOperation>(p_a_grid, p_b_grid, p_c_grid, M, N, K,
+                                                               K0, k_batch, K, N, N, MPadded, NPadded, block_2_ctile_map);
+    }
+    else
+    {
+        GridwiseTsmm::template Run<HasMainKBlockLoop,
+                                   HasDoubleTailKBlockLoop,
+                                   GridwiseTsmm,
+                                   CGlobalMemoryDataOperation>(p_a_grid, p_b_grid, p_c_grid, M, N, K,
+                                                               K0, k_batch, K, K, N, MPadded, NPadded, block_2_ctile_map);
+    }
 }
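The kernel entry point now takes three pointers and seven `index_t` scalars instead of one `Argument` struct, which is the shape the `--amdgpu-kernarg-preload-count=16` flag targets. A back-of-the-envelope check of that budget (hypothetical, not from the commit; it assumes a 64-bit target, 32-bit `index_t`, and that the by-value `Block2CTileMap` adds no preloaded dwords when it is an empty tag type):

```cpp
// Hypothetical accounting of the flattened signature's kernarg footprint.
#include <cstdint>

using index_t = int32_t; // assumption: index_t is 32-bit in this build

constexpr int dwords_per_pointer = sizeof(void*) / 4;   // 2 on a 64-bit target
constexpr int dwords_per_index   = sizeof(index_t) / 4; // 1

// p_a_grid, p_b_grid, p_c_grid + M, N, K, K0, k_batch, MPadded, NPadded
constexpr int kernarg_dwords = 3 * dwords_per_pointer + 7 * dwords_per_index;

static_assert(kernarg_dwords <= 16,
              "all pointer/scalar args fit within the 16 preloaded dwords");

int main() { return 0; }
```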
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
@@ -68,8 +83,8 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
 struct GridwiseTsmmDl_km_kn_mn
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -81,17 +96,17 @@ struct GridwiseTsmmDl_km_kn_mn
     // Argument
     struct Argument : public tensor_operation::device::BaseArgument //
     {
-        Argument(const FloatAB* p_a_grid_,
-                 const FloatAB* p_b_grid_,
-                 FloatC* p_c_grid_,
+        Argument(const FloatAB *p_a_grid_,
+                 const FloatAB *p_b_grid_,
+                 FloatC *p_c_grid_,
                  index_t M_,
                  index_t N_,
                  index_t K_,
                  index_t StrideA_,
                  index_t StrideB_,
                  index_t StrideC_,
-                 // index_t MPadded_,
-                 // index_t NPadded_,
+                 index_t MPadded_,
+                 index_t NPadded_,
                  // index_t KPadded_,
                  index_t K0_,
                  index_t k_batch_)
@@ -104,8 +119,8 @@ struct GridwiseTsmmDl_km_kn_mn
               StrideA{StrideA_},
               StrideB{StrideB_},
               StrideC{StrideC_},
-              // MPadded(MPadded_),
-              // NPadded(NPadded_),
+              MPadded(MPadded_),
+              NPadded(NPadded_),
               // KPadded(KPadded_),
               K0(K0_),
               k_batch(k_batch_)
@@ -113,15 +128,15 @@ struct GridwiseTsmmDl_km_kn_mn
         }
 
         // private:
-        const FloatAB* p_a_grid;
-        const FloatAB* p_b_grid;
-        FloatC* p_c_grid;
+        const FloatAB *p_a_grid;
+        const FloatAB *p_b_grid;
+        FloatC *p_c_grid;
         index_t M, N, K;
         index_t StrideA, StrideB, StrideC;
         //:
-        // index_t MPadded;
-        // index_t NPadded;
+        index_t MPadded;
+        index_t NPadded;
         // index_t KPadded;
         index_t K0;
         index_t k_batch;
@@ -199,18 +214,19 @@ struct GridwiseTsmmDl_km_kn_mn
         index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
     {
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+        const auto a_grid_desc_m_k = [&]()
+        {
+            if constexpr (is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
             }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
+            else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
             }
         }();
 
-        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             return transform_tensor_descriptor(
@@ -239,18 +255,19 @@ struct GridwiseTsmmDl_km_kn_mn
         index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
     {
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+        const auto b_grid_desc_k_n = [&]()
+        {
+            if constexpr (is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
             }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
             }
         }();
 
-        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             return transform_tensor_descriptor(
@@ -273,18 +290,19 @@ struct GridwiseTsmmDl_km_kn_mn
     __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
     {
-        const auto c_grid_desc_m_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+        const auto c_grid_desc_m_n = [&]()
+        {
+            if constexpr (is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
             }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
+            else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
             }
         }();
 
-        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -317,15 +335,15 @@ struct GridwiseTsmmDl_km_kn_mn
     using BGridDesc_Kbatch_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
 
-    __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
+    __host__ __device__ static constexpr bool CheckValidity(const Argument &karg)
     {
-        const auto MPadded = CalculateMPadded(karg.M);
-        const auto NPadded = CalculateNPadded(karg.N);
+        // const auto MPadded = CalculateMPadded(karg.M);
+        // const auto NPadded = CalculateNPadded(karg.N);
         const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
+            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
         const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
+            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
 
         const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
@@ -343,7 +361,7 @@ struct GridwiseTsmmDl_km_kn_mn
     // KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
     __host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
-        const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1)
+        const AGridDesc_Kbatch_K0_M_K1 &a_grid_desc_kbatch_k0_m_k1)
     {
         const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
         const auto K0     = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
@@ -365,7 +383,7 @@ struct GridwiseTsmmDl_km_kn_mn
     }
 
     __host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
-        const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1)
+        const BGridDesc_Kbatch_K0_N_K1 &b_grid_desc_kbatch_k0_n_k1)
     {
         const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
         const auto K0     = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
@@ -387,7 +405,7 @@ struct GridwiseTsmmDl_km_kn_mn
     }
 
     __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
+    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);
@@ -433,27 +451,21 @@ struct GridwiseTsmmDl_km_kn_mn
               bool HasDoubleTailKBlockLoop,
               typename GridwiseTsmm,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-    __device__ static void Run(const Argument& karg)
+    __device__ static void Run(const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid,
+                               index_t M, index_t N, index_t K, index_t K0, index_t k_batch,
+                               index_t StrideA, index_t StrideB, index_t StrideC,
+                               index_t MPadded, index_t NPadded, const Block2CTileMap &block_2_ctile_map)
     {
         constexpr index_t shared_block_size =
             GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
 
         __shared__ FloatAB p_shared_block[shared_block_size];
 
-        const Block2CTileMap& block_2_ctile_map = Block2CTileMap{};
-        const auto MPadded = CalculateMPadded(karg.M);
-        const auto NPadded = CalculateNPadded(karg.N);
-        const FloatAB* p_a_grid = karg.p_a_grid;
-        const FloatAB* p_b_grid = karg.p_b_grid;
-        FloatC* p_c_grid = karg.p_c_grid;
-
         const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0); //
+            M, MPadded, K, StrideA, k_batch, K0); //
         const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0); //
+            K, NPadded, N, StrideB, k_batch, K0); //
         const auto c_grid_desc_m_n =
-            GridwiseTsmm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
+            GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
 
         const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
             GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
@@ -471,14 +483,14 @@ struct GridwiseTsmmDl_km_kn_mn
             p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
 
         const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
-            get_block_1d_id(), karg.N, karg.k_batch);
+            get_block_1d_id(), N, k_batch);
 
         // HACK: this force index data into SGPR
         const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
         const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
         const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);
 
-        if(!block_2_ctile_map.ValidCTileIndex(
+        if (!block_2_ctile_map.ValidCTileIndex(
                make_tuple(im0, in0),
                make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                           c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
@@ -581,7 +593,7 @@ struct GridwiseTsmmDl_km_kn_mn
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
 
-        FloatAB* p_a_block_double = p_shared_block;
+        FloatAB *p_a_block_double = p_shared_block;
 
         auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             b_k0_n_k1_thread_desc.GetElementSpaceSize());
@@ -620,9 +632,9 @@ struct GridwiseTsmmDl_km_kn_mn
                                          b_thread_even_buf);
         }
 
-        if constexpr(HasMainKBlockLoop)
+        if constexpr (HasMainKBlockLoop)
         {
-            const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
+            // const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
 
             index_t k_block_data_begin = 0;
@@ -679,11 +691,11 @@ struct GridwiseTsmmDl_km_kn_mn
                 a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);
 
                 k_block_data_begin += 2 * K0PerBlock;
-            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
+            } while (k_block_data_begin < K0 - 2 * K0PerBlock);
        }
 
         // LDS double buffer: tail
-        if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
+        if constexpr (HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                 a_block_slice_copy_step);
@@ -768,5 +780,5 @@ struct GridwiseTsmmDl_km_kn_mn
                                       c_grid_buf);
         }
     }
 };
 
 } // namespace ck