Commit f3e5a74e authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

Gemm Kernel Refactor part1

parent feb9a2bd
...@@ -28,6 +28,10 @@ struct GemmKernel ...@@ -28,6 +28,10 @@ struct GemmKernel
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>; // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>; using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{ {
return TilePartitioner::GridSize(M, N, KBatch); return TilePartitioner::GridSize(M, N, KBatch);
...@@ -139,13 +143,11 @@ struct GemmKernel ...@@ -139,13 +143,11 @@ struct GemmKernel
return true; return true;
} }
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const CK_TILE_DEVICE auto make_gemm_tensor_views(const ADataType* a_start,
const BDataType* b_start,
CDataType* c_start,
const GemmCommonKargs& kargs) const
{ {
const auto [i_m, i_n] = TilePartitioner{}();
// options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() { auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
...@@ -188,7 +190,43 @@ struct GemmKernel ...@@ -188,7 +190,43 @@ struct GemmKernel
} }
}(); }();
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
// options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
// Convert pointers to tensor views
const auto gemm_tensor_views_tuple =
make_gemm_tensor_views(a_start, b_start, c_start, kargs);
auto a_pad_view = [&]() { auto a_pad_view = [&]() {
auto a_tensor_view = gemm_tensor_views_tuple.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -212,6 +250,7 @@ struct GemmKernel ...@@ -212,6 +250,7 @@ struct GemmKernel
{i_m, 0}); {i_m, 0});
auto b_pad_view = [&]() { auto b_pad_view = [&]() {
auto b_tensor_view = gemm_tensor_views_tuple.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -242,29 +281,8 @@ struct GemmKernel ...@@ -242,29 +281,8 @@ struct GemmKernel
auto c_block_tile = auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
auto c_pad_view = [&]() { auto c_pad_view = [&]() {
auto c_tensor_view = gemm_tensor_views_tuple.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment