Commit c2784145 authored by Harisankar Sadasivan

Revert "kernarg load latency optimization for mi300"

This reverts commit 8861bd66.
parent 8861bd66
@@ -6,7 +6,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
     add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
     add_dependencies(example_gemv_splitk
                      example_gemv_splitk_fp16)
-    set_source_files_properties(gemv_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
     set(target 1)
   endif()
 endforeach()
@@ -6,7 +6,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
     add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
     add_dependencies(example_tall_and_skinny_gemm_splitk
                      example_tall_and_skinny_gemm_splitk_fp16)
-    set_source_files_properties(tall_and_skinny_gemm_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
     set(target 1)
   endif()
 endforeach()
\ No newline at end of file
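Note on what the revert removes here: `-mllvm;--amdgpu-kernarg-preload-count=16` asks the AMDGPU backend to preload up to 16 dwords of the kernel-argument segment into SGPRs before the kernel body runs, so the kernel skips the s_load of those arguments at entry, and `-DKERNARG_PRELOAD` selected the matching host-side launch path further below. A minimal, hypothetical HIP sketch of the kind of kernel this flag targets (all names invented for illustration):

#include <hip/hip_runtime.h>

// Hypothetical kernel whose arguments fit in the preload budget.
// Built with: hipcc ... -mllvm --amdgpu-kernarg-preload-count=16
struct KernArg
{
    const float* a; // 2 dwords
    const float* b; // 2 dwords
    float* c;       // 2 dwords
    int n;          // 1 dword -> 7 dwords total, within the 16-dword budget
};

__global__ void vec_add(KernArg karg)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < karg.n)
        karg.c[i] = karg.a[i] + karg.b[i]; // arguments arrive preloaded in SGPRs
}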
@@ -3,7 +3,7 @@
 #pragma once
-bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
+bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
 {
 #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
     static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
@@ -72,9 +72,9 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
     auto c_element_op = CElementOp{};
     // do GEMM
-    auto gemv     = DeviceGemvInstance{};
-    auto invoker  = gemv.MakeInvoker();
-    auto argument = gemv.MakeArgument(
+    auto tsmm     = DeviceTSMMInstance{};
+    auto invoker  = tsmm.MakeInvoker();
+    auto argument = tsmm.MakeArgument(
 #ifdef BUILD_INT4_EXAMPLE
         static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
         static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
@@ -96,24 +96,22 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
         k_batch); // //
     // //
-    if(!gemv.IsSupportedArgument(argument))
+    if(!tsmm.IsSupportedArgument(argument))
     {
-        std::cerr << gemv.GetTypeString() << " does not support this problem" << std::endl;
+        std::cerr << tsmm.GetTypeString() << " does not support this problem" << std::endl;
         return true;
     }
     c_m_n_device_buf.SetZero();
-    invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
     if(config.do_verification)
     {
+        invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
-        auto ref_gemv     = ReferenceGemmInstance{};
-        auto ref_invoker  = ref_gemv.MakeInvoker();
-        auto ref_argument = ref_gemv.MakeArgument(
+        auto ref_tsmm     = ReferenceGemmInstance{};
+        auto ref_invoker  = ref_tsmm.MakeInvoker();
+        auto ref_argument = ref_tsmm.MakeArgument(
             a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
         ref_invoker.Run(ref_argument);
@@ -143,7 +141,7 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
     float gb_per_sec = num_btype / 1.E6 / ave_time;
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-              << gemv.GetTypeString() << std::endl;
+              << tsmm.GetTypeString() << std::endl;
 #ifdef BUILD_INT4_EXAMPLE
     return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
@@ -152,7 +150,7 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
 #endif
 }
-bool run_gemv_example(int argc, char* argv[])
+bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
 {
     ProblemSize problem_size;
     ExecutionConfig config;
@@ -192,5 +190,5 @@ bool run_gemv_example(int argc, char* argv[])
         exit(0);
     }
-    return run_gemv(problem_size, config);
+    return run_tall_and_skinny_gemm(problem_size, config);
 }
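The restored driver keeps the usual CK example flow: zero the output buffer, run the device kernel once untimed, compute a host reference, then compare elementwise. A condensed, hypothetical sketch of that comparison step (ck::utils::check_err approximated with a simple tolerance loop):

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the example's verification step.
bool verify(const std::vector<float>& device_result,
            const std::vector<float>& host_reference,
            float rtol = 1e-3f,
            float atol = 1e-3f)
{
    for(std::size_t i = 0; i < host_reference.size(); ++i)
    {
        const float err = std::fabs(device_result[i] - host_reference[i]);
        if(err > atol + rtol * std::fabs(host_reference[i]))
            return false; // first out-of-tolerance element fails the check
    }
    return true;
}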
@@ -9,7 +9,6 @@
 #include "ck/stream_config.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
-#ifndef KERNARG_PRELOAD
 template <typename... Args, typename F>
 float launch_and_time_kernel(const StreamConfig& stream_config,
                              F kernel,
@@ -79,80 +78,6 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
 #endif
 }
-#else
-template <typename... Args, typename F>
-float launch_and_time_kernel(const StreamConfig& stream_config,
-                             F kernel,
-                             dim3 grid_dim,
-                             dim3 block_dim,
-                             std::size_t lds_byte,
-                             Args... args)
-{
-    // Args* args1;
-    // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
-    // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
-#if CK_TIME_KERNEL
-    if(stream_config.time_kernel_)
-    {
-#if DEBUG_LOG
-        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
-               __func__,
-               grid_dim.x,
-               grid_dim.y,
-               grid_dim.z,
-               block_dim.x,
-               block_dim.y,
-               block_dim.z);
-        printf("Warm up 1 time\n");
-#endif
-        //
-        // warm up
-        const int nrepeat = 1000;
-        for(auto i = 0; i < nrepeat; i++)
-            hipLaunchKernelGGL(
-                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
-        hip_check_error(hipGetLastError());
-#if DEBUG_LOG
-        printf("Start running %d times...\n", nrepeat);
-#endif
-        hipEvent_t start, stop;
-        float total_time = 0;
-        hip_check_error(hipEventCreate(&start));
-        hip_check_error(hipEventCreate(&stop));
-        hip_check_error(hipDeviceSynchronize());
-        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
-        for(int i = 0; i < nrepeat; ++i)
-            hipLaunchKernelGGL(
-                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
-        // hip_check_error(hipGetLastError());
-        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
-        hip_check_error(hipEventSynchronize(stop));
-        hip_check_error(hipEventElapsedTime(&total_time, start, stop));
-        return total_time / nrepeat;
-    }
-    else
-    {
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-        return 0;
-    }
-#else
-    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
-    return 0;
-#endif
-}
-#endif
 template <typename... Args, typename F, typename PreProcessFunc>
 float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              PreProcessFunc preprocess,
...
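The deleted `#else` branch replaced CK's per-launch timing with a batched measurement: 1000 warm-up launches, then 1000 timed launches bracketed by a single pair of HIP events, reporting the mean. A self-contained sketch of that pattern (hypothetical helper name; the event API is standard HIP), trading per-launch variance data for lower measurement overhead:

#include <hip/hip_runtime.h>

// Time nrepeat back-to-back launches with one hipEvent pair and
// return the mean per-launch time in milliseconds.
template <typename Kernel, typename... Args>
float time_kernel_batched(Kernel kernel, dim3 grid, dim3 block, hipStream_t stream, Args... args)
{
    const int nrepeat = 1000;

    for(int i = 0; i < nrepeat; ++i) // warm up: stabilize clocks and caches
        kernel<<<grid, block, 0, stream>>>(args...);

    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);
    (void)hipDeviceSynchronize();

    (void)hipEventRecord(start, stream);
    for(int i = 0; i < nrepeat; ++i)
        kernel<<<grid, block, 0, stream>>>(args...);
    (void)hipEventRecord(stop, stream);
    (void)hipEventSynchronize(stop);

    float total_ms = 0.f;
    (void)hipEventElapsedTime(&total_ms, start, stop);
    (void)hipEventDestroy(start);
    (void)hipEventDestroy(stop);
    return total_ms / nrepeat; // mean amortizes launch overhead across the batch
}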
@@ -117,7 +117,7 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
 {
     const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
-    const auto b2c_map = DefaultBlock2CTileMap{};
+    // const auto b2c_map = DefaultBlock2CTileMap{};
     const auto K0 = karg.K0;
@@ -138,54 +138,24 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                 ADataType,
                                                 CDataType,
-                                                BLayout,
                                                 InMemoryDataOperationEnum::Set,
                                                 true,
                                                 true,
                                                 DefaultBlock2CTileMap>; // //
-        ave_time = launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          karg.p_a_grid,
-                                          karg.p_b_grid,
-                                          karg.p_c_grid,
-                                          (karg.M),
-                                          (karg.N),
-                                          (karg.K),
-                                          (karg.K0),
-                                          (karg.k_batch),
-                                          karg.MPadded,
-                                          karg.NPadded,
-                                          b2c_map);
+        ave_time = launch_and_time_kernel(
+            stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
     }
     else
     {
         const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                 ADataType,
                                                 CDataType,
-                                                BLayout,
                                                 InMemoryDataOperationEnum::AtomicAdd,
                                                 true,
                                                 true,
                                                 DefaultBlock2CTileMap>; // //
-        ave_time = launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          karg.p_a_grid,
-                                          karg.p_b_grid,
-                                          karg.p_c_grid,
-                                          (karg.M),
-                                          (karg.N),
-                                          (karg.K),
-                                          (karg.K0),
-                                          (karg.k_batch),
-                                          karg.MPadded,
-                                          karg.NPadded,
-                                          b2c_map);
+        ave_time = launch_and_time_kernel(
+            stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
     }
 }
 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
@@ -196,54 +166,24 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                 ADataType,
                                                 CDataType,
-                                                BLayout,
                                                 InMemoryDataOperationEnum::Set,
                                                 true,
                                                 false,
                                                 DefaultBlock2CTileMap>; // //
-        ave_time = launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          karg.p_a_grid,
-                                          karg.p_b_grid,
-                                          karg.p_c_grid,
-                                          (karg.M),
-                                          (karg.N),
-                                          (karg.K),
-                                          (karg.K0),
-                                          (karg.k_batch),
-                                          karg.MPadded,
-                                          karg.NPadded,
-                                          b2c_map);
+        ave_time = launch_and_time_kernel(
+            stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
     }
     else
     {
         const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                 ADataType,
                                                 CDataType,
-                                                BLayout,
                                                 InMemoryDataOperationEnum::AtomicAdd,
                                                 true,
                                                 false,
                                                 DefaultBlock2CTileMap>; // //
-        ave_time = launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          karg.p_a_grid,
-                                          karg.p_b_grid,
-                                          karg.p_c_grid,
-                                          (karg.M),
-                                          (karg.N),
-                                          (karg.K),
-                                          (karg.K0),
-                                          (karg.k_batch),
-                                          karg.MPadded,
-                                          karg.NPadded,
-                                          b2c_map);
+        ave_time = launch_and_time_kernel(
+            stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
     }
 }
 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
@@ -253,54 +193,24 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                 ADataType,
                                                 CDataType,
-                                                BLayout,
                                                 InMemoryDataOperationEnum::Set,
                                                 false,
                                                 true,
                                                 DefaultBlock2CTileMap>; // //
-        ave_time = launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          karg.p_a_grid,
-                                          karg.p_b_grid,
-                                          karg.p_c_grid,
-                                          (karg.M),
-                                          (karg.N),
-                                          (karg.K),
-                                          (karg.K0),
-                                          (karg.k_batch),
-                                          karg.MPadded,
-                                          karg.NPadded,
-                                          b2c_map);
+        ave_time = launch_and_time_kernel(
+            stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
     }
     else
     {
         const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                 ADataType,
                                                 CDataType,
-                                                BLayout,
                                                 InMemoryDataOperationEnum::AtomicAdd,
                                                 false,
                                                 true,
                                                 DefaultBlock2CTileMap>; // //
-        ave_time = launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          karg.p_a_grid,
-                                          karg.p_b_grid,
-                                          karg.p_c_grid,
-                                          (karg.M),
-                                          (karg.N),
-                                          (karg.K),
-                                          (karg.K0),
-                                          (karg.k_batch),
-                                          karg.MPadded,
-                                          karg.NPadded,
-                                          b2c_map);
+        ave_time = launch_and_time_kernel(
+            stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
    }
 }
 else
@@ -310,59 +220,30 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                 ADataType,
                                                 CDataType,
-                                                BLayout,
                                                 InMemoryDataOperationEnum::Set,
                                                 false,
                                                 false,
                                                 DefaultBlock2CTileMap>; // //
-        ave_time = launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          karg.p_a_grid,
-                                          karg.p_b_grid,
-                                          karg.p_c_grid,
-                                          (karg.M),
-                                          (karg.N),
-                                          (karg.K),
-                                          (karg.K0),
-                                          (karg.k_batch),
-                                          karg.MPadded,
-                                          karg.NPadded,
-                                          b2c_map);
+        ave_time = launch_and_time_kernel(
+            stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
     }
     else
     {
         const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                 ADataType,
                                                 CDataType,
-                                                BLayout,
                                                 InMemoryDataOperationEnum::AtomicAdd,
                                                 false,
                                                 false,
                                                 DefaultBlock2CTileMap>; // //
-        ave_time = launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          karg.p_a_grid,
-                                          karg.p_b_grid,
-                                          karg.p_c_grid,
-                                          (karg.M),
-                                          (karg.N),
-                                          (karg.K),
-                                          (karg.K0),
-                                          (karg.k_batch),
-                                          karg.MPadded,
-                                          karg.NPadded,
-                                          b2c_map);
+        ave_time = launch_and_time_kernel(
+            stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
     }
 }
 return ave_time;
 }
 // polymorphic
 float Run(const BaseArgument* p_arg,
           const StreamConfig& stream_config = StreamConfig{}) override
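The eight branches above vary only two compile-time loop flags and the C-update operation: with split-K (`k_batch > 1`) each K-slice must accumulate into C via AtomicAdd, which is why the caller zeroes the C buffer first, while the single-batch path can write with a plain Set. A hedged, naive sketch of the split-K accumulation idea (hypothetical kernel, nothing like CK's blocked pipeline; launched with gridDim.z == k_batch):

#include <hip/hip_runtime.h>

// Each z-block owns roughly K/k_batch of the reduction and atomically
// adds its partial sums into C, which must be zeroed before launch.
__global__ void gemm_splitk_naive(const float* a, const float* b, float* c,
                                  int M, int N, int K, int k_batch)
{
    const int m = blockIdx.y * blockDim.y + threadIdx.y;
    const int n = blockIdx.x * blockDim.x + threadIdx.x;
    if(m >= M || n >= N)
        return;

    const int k_per_batch = (K + k_batch - 1) / k_batch;
    const int k_begin     = blockIdx.z * k_per_batch;
    const int k_end       = min(k_begin + k_per_batch, K);

    float partial = 0.f;
    for(int k = k_begin; k < k_end; ++k)
        partial += a[m * K + k] * b[k * N + n]; // row-major A and B

    atomicAdd(&c[m * N + n], partial); // the AtomicAdd path above
}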
@@ -382,8 +263,7 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
     if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
        ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
        ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-       ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx940" ||
-       ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
+       ck::get_device_name() == "gfx1102")
     {
         return GridwiseTsmm::CheckValidity(arg);
     }
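The revert also drops gfx940/gfx941/gfx942 (the MI300 family) from this allow-list, so the instance again reports those targets as unsupported. `ck::get_device_name()` is CK's wrapper for this query; a hypothetical standalone equivalent using only the public HIP API:

#include <hip/hip_runtime.h>
#include <string>

// Return the bare arch token ("gfx90a", "gfx942", ...) of the current
// device. gcnArchName may carry feature suffixes such as
// "gfx90a:sramecc+:xnack-", so cut the string at the first ':'.
std::string get_device_arch()
{
    int dev = 0;
    (void)hipGetDevice(&dev);
    hipDeviceProp_t prop{};
    (void)hipGetDeviceProperties(&prop, dev);
    const std::string full = prop.gcnArchName;
    return full.substr(0, full.find(':'));
}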
@@ -422,8 +302,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                 StrideA,
                 StrideB,
                 StrideC,
-                GridwiseTsmm::CalculateMPadded(M),
-                GridwiseTsmm::CalculateNPadded(N),
+                // GridwiseTsmm::CalculateMPadded(M),
+                // GridwiseTsmm::CalculateNPadded(N),
                 // GridwiseTsmm::CalculateKPadded(K, KBatch),
                 GridwiseTsmm::CalculateK0(K, KBatch),
                 KBatch}; // //
@@ -456,8 +336,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                 StrideA,
                 StrideB,
                 StrideC,
-                GridwiseTsmm::CalculateMPadded(M),
-                GridwiseTsmm::CalculateNPadded(N),
+                // GridwiseTsmm::CalculateMPadded(M),
+                // GridwiseTsmm::CalculateNPadded(N),
                 // GridwiseTsmm::CalculateKPadded(K, KBatch),
                 GridwiseTsmm::CalculateK0(K, KBatch),
                 KBatch); // //
...
@@ -21,70 +21,23 @@ namespace ck {
 template <typename GridwiseTsmm,
           typename FloatAB,
           typename FloatC,
-          typename BLayout,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop,
           typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     kernel_tsmm_dl_v1r3(
-        const FloatAB* p_a_grid,
-        const FloatAB* p_b_grid,
-        FloatC* p_c_grid,
-        index_t M,
-        index_t N,
-        index_t K,
-        index_t K0,
-        index_t k_batch,
-        index_t MPadded,
-        index_t NPadded,
-        const Block2CTileMap block_2_ctile_map) //: in __global__ functions, struct is
-                                                // better for reduced load overhead
+        typename GridwiseTsmm::Argument karg) //: in __global__ functions, struct is
+                                              // better for reduced load overhead
 {
-    // strides depend on B's layout
-    if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-    {
-        GridwiseTsmm::template Run<HasMainKBlockLoop,
-                                   HasDoubleTailKBlockLoop,
-                                   GridwiseTsmm,
-                                   CGlobalMemoryDataOperation>(p_a_grid,
-                                                               p_b_grid,
-                                                               p_c_grid,
-                                                               M,
-                                                               N,
-                                                               K,
-                                                               K0,
-                                                               k_batch,
-                                                               K,
-                                                               N,
-                                                               N,
-                                                               MPadded,
-                                                               NPadded,
-                                                               block_2_ctile_map);
-    }
-    else
-    {
-        GridwiseTsmm::template Run<HasMainKBlockLoop,
-                                   HasDoubleTailKBlockLoop,
-                                   GridwiseTsmm,
-                                   CGlobalMemoryDataOperation>(p_a_grid,
-                                                               p_b_grid,
-                                                               p_c_grid,
-                                                               M,
-                                                               N,
-                                                               K,
-                                                               K0,
-                                                               k_batch,
-                                                               K,
-                                                               K,
-                                                               N,
-                                                               MPadded,
-                                                               NPadded,
-                                                               block_2_ctile_map);
-    }
+    GridwiseTsmm::template Run<HasMainKBlockLoop,
+                               HasDoubleTailKBlockLoop,
+                               GridwiseTsmm,
+                               CGlobalMemoryDataOperation>(karg);
 }
 template <index_t BlockSize,
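The restored signature follows its own inline comment: in `__global__` functions a single argument struct is preferred for reduced load overhead, since every field travels as one aggregate in the kernarg segment and the compiler can schedule those loads together. A minimal sketch contrasting the two styles (hypothetical names; which wins in practice depends on how the backend lays out and fetches kernargs):

#include <hip/hip_runtime.h>

struct Args
{
    const float* a;
    float* c;
    int m, n;
};

// Style restored by the revert: one aggregate kernel argument.
__global__ void kernel_struct(Args karg)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < karg.m * karg.n)
        karg.c[i] = karg.a[i];
}

// Style being reverted: each field is a separate kernel parameter.
__global__ void kernel_unpacked(const float* a, float* c, int m, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < m * n)
        c[i] = a[i];
}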
@@ -137,8 +90,8 @@ struct GridwiseTsmmDl_km_kn_mn
              index_t StrideA_,
              index_t StrideB_,
              index_t StrideC_,
-             index_t MPadded_,
-             index_t NPadded_,
+             // index_t MPadded_,
+             // index_t NPadded_,
              // index_t KPadded_,
              index_t K0_,
              index_t k_batch_)
@@ -151,8 +104,8 @@ struct GridwiseTsmmDl_km_kn_mn
           StrideA{StrideA_},
           StrideB{StrideB_},
           StrideC{StrideC_},
-          MPadded(MPadded_),
-          NPadded(NPadded_),
+          // MPadded(MPadded_),
+          // NPadded(NPadded_),
           // KPadded(KPadded_),
           K0(K0_),
           k_batch(k_batch_)
@@ -167,8 +120,8 @@ struct GridwiseTsmmDl_km_kn_mn
     index_t M, N, K;
     index_t StrideA, StrideB, StrideC;
     //:
-    index_t MPadded;
-    index_t NPadded;
+    // index_t MPadded;
+    // index_t NPadded;
     // index_t KPadded;
     index_t K0;
     index_t k_batch;
@@ -367,12 +320,12 @@ struct GridwiseTsmmDl_km_kn_mn
 __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
 {
-    // const auto MPadded = CalculateMPadded(karg.M);
-    // const auto NPadded = CalculateNPadded(karg.N);
+    const auto MPadded = CalculateMPadded(karg.M);
+    const auto NPadded = CalculateNPadded(karg.N);
     const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
-        karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
+        karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
     const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
-        karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
+        karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
     const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
     const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
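With the revert, `MPadded`/`NPadded` are recomputed on demand instead of being carried as `Argument` fields. The computation itself is plain round-up-to-tile arithmetic; a hedged sketch of what `CalculateMPadded`-style helpers typically do (tile sizes invented for illustration):

// Hypothetical round-up helpers in the spirit of CalculateMPadded/CalculateNPadded:
// pad a dimension to the next multiple of the block tile so every workgroup
// maps onto a full tile and edge handling stays uniform.
constexpr int pad_to(int x, int tile) { return ((x + tile - 1) / tile) * tile; }

constexpr int MPerBlock = 128; // illustrative tile sizes
constexpr int NPerBlock = 16;

static_assert(pad_to(1000, MPerBlock) == 1024, "1000 rows pad up to 8 M-tiles");
static_assert(pad_to(30, NPerBlock) == 32, "30 cols pad up to 2 N-tiles");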
@@ -480,32 +433,27 @@ struct GridwiseTsmmDl_km_kn_mn
           bool HasDoubleTailKBlockLoop,
           typename GridwiseTsmm,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-__device__ static void Run(const FloatAB* p_a_grid,
-                           const FloatAB* p_b_grid,
-                           FloatC* p_c_grid,
-                           index_t M,
-                           index_t N,
-                           index_t K,
-                           index_t K0,
-                           index_t k_batch,
-                           index_t StrideA,
-                           index_t StrideB,
-                           index_t StrideC,
-                           index_t MPadded,
-                           index_t NPadded,
-                           const Block2CTileMap& block_2_ctile_map)
+__device__ static void Run(const Argument& karg)
 {
     constexpr index_t shared_block_size =
         GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
     __shared__ FloatAB p_shared_block[shared_block_size];
+    const Block2CTileMap& block_2_ctile_map = Block2CTileMap{};
+    const auto MPadded = CalculateMPadded(karg.M);
+    const auto NPadded = CalculateNPadded(karg.N);
+    const FloatAB* p_a_grid = karg.p_a_grid;
+    const FloatAB* p_b_grid = karg.p_b_grid;
+    FloatC* p_c_grid        = karg.p_c_grid;
     const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
-        M, MPadded, K, StrideA, k_batch, K0); //
+        karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0); //
     const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
-        K, NPadded, N, StrideB, k_batch, K0); //
+        karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0); //
-    const auto c_grid_desc_m_n = GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
+    const auto c_grid_desc_m_n =
+        GridwiseTsmm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
     const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
         GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
@@ -522,8 +470,8 @@ struct GridwiseTsmmDl_km_kn_mn
     auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
         p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
-    const auto c_m0_n0_block_cluster_idx =
-        block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(get_block_1d_id(), N, k_batch);
+    const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
+        get_block_1d_id(), karg.N, karg.k_batch);
     // HACK: this force index data into SGPR
     const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
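The "HACK" comment above deserves a gloss: block/tile indices are uniform across a wavefront, but the compiler cannot always prove that, so `__builtin_amdgcn_readfirstlane` broadcasts lane 0's value and thereby pins the result in a scalar register (SGPR); subsequent address arithmetic then uses scalar instructions instead of per-lane VGPR math. A small hypothetical illustration (AMD-only builtin):

#include <hip/hip_runtime.h>

__global__ void tile_offset_demo(const float* src, float* dst, int tiles_per_row)
{
    // blockIdx-derived values are wave-uniform; readfirstlane tells the
    // compiler so, keeping the tile coordinates and address math in SGPRs.
    const int tile_m = __builtin_amdgcn_readfirstlane(blockIdx.x / tiles_per_row);
    const int tile_n = __builtin_amdgcn_readfirstlane(blockIdx.x % tiles_per_row);

    const int offset = (tile_m * tiles_per_row + tile_n) * 64 + threadIdx.x;
    dst[offset] = src[offset];
}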
@@ -559,8 +507,8 @@ struct GridwiseTsmmDl_km_kn_mn
         decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1), // block tensor desc
         ABlockTransferSrcAccessOrder,                   // 5-dim
         Sequence<0, 1, 2, 3, 4>,
         ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
         ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
         ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
         Sequence<0, 1, 2, 3, 4>,                         // DstVectorTensorContiguousDimOrder
         false,
@@ -661,7 +609,7 @@ struct GridwiseTsmmDl_km_kn_mn
     // LDS double buffer: preload data into LDS
     {
         a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1,
                                  a_global_buf); // a_global_buf -> reg_tmp_buf
         a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1,
                                   a_block_even_buf); // reg_tmp_buf->a_block_even_buf
@@ -674,7 +622,7 @@ struct GridwiseTsmmDl_km_kn_mn
     if constexpr(HasMainKBlockLoop)
     {
-        // const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
+        const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
         index_t k_block_data_begin = 0;
...