Commit 8b914b24 authored by Aleksander Dudek

Gemm Kernel: Refactor common gemm pipeline, part 2

parent f3e5a74e
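Summary: operator() no longer builds the padded views and tile windows inline; it now delegates to make_gemm_pad_views, make_gemm_tile_windows and run_common_gemm_pipeline, presumably so other GEMM kernel variants can share the same staging. For orientation, below is a minimal, self-contained C++17 sketch of that staging pattern (raw pointers -> views -> padded views -> per-workgroup tile windows -> pipeline -> epilogue). Everything in the sketch (ToyView, ToyWindow, make_views, pad_views, make_windows, run_pipeline) is a hypothetical stand-in for illustration only; the real types and helpers are the ck_tile ones shown in the diff below.

// Hypothetical, self-contained C++17 sketch (not ck_tile code): all names here
// are invented stand-ins that only mirror the staging of the refactored kernel.
#include <cstdio>
#include <tuple>

struct ToyView   { const float* ptr; int rows; int cols; };                  // stand-in for a tensor view
struct ToyWindow { ToyView view; int row0; int col0; int rows; int cols; };  // stand-in for a tile window

// Stage 1: wrap raw pointers into views (role of make_gemm_tensor_views).
std::tuple<ToyView, ToyView, ToyView>
make_views(const float* a, const float* b, float* c, int M, int N, int K)
{
    return {ToyView{a, M, K}, ToyView{b, N, K}, ToyView{c, M, N}};
}

// Stage 2: pad views up to tile multiples (role of make_gemm_pad_views; no-op here).
std::tuple<ToyView, ToyView, ToyView>
pad_views(const std::tuple<ToyView, ToyView, ToyView>& views)
{
    return views;
}

// Stage 3: carve this workgroup's windows out of the padded views
// (role of make_gemm_tile_windows).
std::tuple<ToyWindow, ToyWindow, ToyWindow>
make_windows(const std::tuple<ToyView, ToyView, ToyView>& views,
             int i_m, int i_n, int tile_m, int tile_n, int tile_k)
{
    const auto& [a_view, b_view, c_view] = views;
    return {ToyWindow{a_view, i_m, 0, tile_m, tile_k},
            ToyWindow{b_view, i_n, 0, tile_n, tile_k},
            ToyWindow{c_view, i_m, i_n, tile_m, tile_n}};
}

// Stages chained together, the role played by run_common_gemm_pipeline.
void run_pipeline(const float* a, const float* b, float* c,
                  int M, int N, int K, int i_m, int i_n)
{
    const auto views   = make_views(a, b, c, M, N, K);
    const auto padded  = pad_views(views);
    const auto windows = make_windows(padded, i_m, i_n, 128, 128, 32);
    const auto& [aw, bw, cw] = windows;
    // The real kernel would now run GemmPipeline on (aw, bw) and EpiloguePipeline on cw.
    std::printf("A@(%d,%d) B@(%d,%d) C@(%d,%d)\n",
                aw.row0, aw.col0, bw.row0, bw.col0, cw.row0, cw.col0);
}

int main()
{
    const float a[1]{};
    const float b[1]{};
    float c[1]{};
    run_pipeline(a, b, c, 1024, 1024, 512, 0, 0);
    return 0;
}

Each stage hands its result to the next as a plain tuple, which is what lets the driver compose the helpers without the intermediate locals that the old operator() kept inline.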
@@ -214,19 +214,11 @@ struct GemmKernel
         return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
     }
 
-    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
+    template <typename TensorView>
+    CK_TILE_DEVICE auto make_gemm_pad_views(TensorView&& views) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
-        // options
-        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
-        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
-        // Convert pointers to tensor views
-        const auto gemm_tensor_views_tuple =
-            make_gemm_tensor_views(a_start, b_start, c_start, kargs);
-
         auto a_pad_view = [&]() {
-            auto a_tensor_view = gemm_tensor_views_tuple.at(I0);
+            auto a_tensor_view = views.at(I0);
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return pad_tensor_view(
@@ -242,15 +234,9 @@ struct GemmKernel
                     sequence<GemmPipeline::kPadM, false>{});
             }
         }();
 
-        // clang-format on
-        auto a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-            {i_m, 0});
-
         auto b_pad_view = [&]() {
-            auto b_tensor_view = gemm_tensor_views_tuple.at(I1);
+            auto b_tensor_view = views.at(I1);
             if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
             {
                 return pad_tensor_view(
@@ -267,22 +253,8 @@ struct GemmKernel
             }
         }();
 
-        auto b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-            {i_n, 0});
-
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
-
-        // Run GEMM cooperatively by whole wokrgroup.
-        auto c_block_tile =
-            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
-
         auto c_pad_view = [&]() {
-            auto c_tensor_view = gemm_tensor_views_tuple.at(I2);
+            auto c_tensor_view = views.at(I2);
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
                 return pad_tensor_view(
@@ -298,12 +270,74 @@ struct GemmKernel
                     sequence<GemmPipeline::kPadM, false>{});
             }
         }();
 
-        auto CBlockWindow_pad = make_tile_window(
+        return make_tuple(a_pad_view, b_pad_view, c_pad_view);
+    }
+
+    template <typename PadView>
+    CK_TILE_DEVICE auto
+    make_gemm_tile_windows(PadView&& views, const index_t i_m, const index_t i_n) const
+    {
+        auto a_pad_view = views.at(I0);
+        auto a_block_window = make_tile_window(
+            a_pad_view,
+            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
+            {i_m, 0});
+
+        auto b_pad_view = views.at(I1);
+        auto b_block_window = make_tile_window(
+            b_pad_view,
+            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
+            {i_n, 0});
+
+        auto c_pad_view = views.at(I2);
+        const auto c_block_window = make_tile_window(
             c_pad_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
             {i_m, i_n});
 
-        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
+        return make_tuple(a_block_window, b_block_window, c_block_window);
+    }
+
+    CK_TILE_DEVICE void run_common_gemm_pipeline(const ADataType* a_start,
+                                                 const BDataType* b_start,
+                                                 CDataType* c_start,
+                                                 const GemmCommonKargs& kargs,
+                                                 const index_t i_m,
+                                                 const index_t i_n) const
+    {
+        // Convert pointers to tensor views
+        const auto gemm_tensor_views_tuple =
+            make_gemm_tensor_views(a_start, b_start, c_start, kargs);
+
+        const auto gemm_pad_views = make_gemm_pad_views(gemm_tensor_views_tuple);
+        const auto gemm_tile_windows = make_gemm_tile_windows(gemm_pad_views, i_m, i_n);
+
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
+
+        auto a_block_window = gemm_tile_windows.at(I0);
+        auto b_block_window = gemm_tile_windows.at(I1);
+
+        // Run GEMM cooperatively by whole workgroup.
+        auto c_block_tile =
+            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
+
+        auto c_block_window = gemm_tile_windows.at(I2);
+        EpiloguePipeline{}(c_block_window, c_block_tile);
+    }
+
+    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
+    {
+        const auto [i_m, i_n] = TilePartitioner{}();
+        // options
+        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
+        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
+        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
+
+        run_common_gemm_pipeline(a_start, b_start, c_start, kargs, i_m, i_n);
+    }
 };