Commit 799cde32 authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK TILE] Refactor GemmKernel - review changes part1

parent ed528d76
...@@ -29,7 +29,7 @@ using BDataType = Types::BDataType; ...@@ -29,7 +29,7 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType; using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType; using CDataType = Types::CDataType;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs struct batched_gemm_kargs : public ck_tile::BatchedGemmHargs
{ {
}; };
......
...@@ -7,17 +7,8 @@ ...@@ -7,17 +7,8 @@
namespace ck_tile { namespace ck_tile {
struct BatchedGemmHostArgs struct BatchedGemmHargs : GemmHargs
{ {
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t batch_stride_A; index_t batch_stride_A;
index_t batch_stride_B; index_t batch_stride_B;
index_t batch_stride_C; index_t batch_stride_C;
...@@ -29,7 +20,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -29,7 +20,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
{ {
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>; using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmCommonKargs = typename Base::GemmCommonKargs; using GemmKargs = typename Base::GemmKargs;
using ADataType = typename Base::ADataType; using ADataType = typename Base::ADataType;
using BDataType = typename Base::BDataType; using BDataType = typename Base::BDataType;
...@@ -42,7 +33,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -42,7 +33,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BLayout = typename Base::BLayout; using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout; using CLayout = typename Base::CLayout;
struct BatchedGemmKargs : GemmCommonKargs struct BatchedGemmKargs : GemmKargs
{ {
index_t batch_stride_A; index_t batch_stride_A;
index_t batch_stride_B; index_t batch_stride_B;
...@@ -51,7 +42,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -51,7 +42,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}; };
using Kargs = BatchedGemmKargs; using Kargs = BatchedGemmKargs;
using Hargs = BatchedGemmHostArgs; using Hargs = BatchedGemmHargs;
__host__ static constexpr auto GridSize(const Hargs& k) __host__ static constexpr auto GridSize(const Hargs& k)
{ {
...@@ -102,7 +93,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -102,7 +93,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C); const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C; CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
this->run_common_gemm_pipeline(a_start, b_start, c_start, kargs, i_m, i_n); this->RunGemm(a_start, b_start, c_start, kargs, i_m, i_n);
} }
}; };
......
...@@ -12,6 +12,19 @@ ...@@ -12,6 +12,19 @@
namespace ck_tile { namespace ck_tile {
struct GemmHargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_> template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel struct GemmKernel
{ {
...@@ -25,7 +38,6 @@ struct GemmKernel ...@@ -25,7 +38,6 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>; using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>; using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>; using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>(); static constexpr auto I0 = number<0>();
...@@ -39,7 +51,7 @@ struct GemmKernel ...@@ -39,7 +51,7 @@ struct GemmKernel
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmCommonKargs struct GemmKargs
{ {
const void* a_ptr; const void* a_ptr;
const void* b_ptr; const void* b_ptr;
...@@ -52,17 +64,17 @@ struct GemmKernel ...@@ -52,17 +64,17 @@ struct GemmKernel
index_t stride_C; index_t stride_C;
}; };
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr, CK_TILE_HOST static constexpr GemmKargs MakeKargs(const void* a_ptr,
const void* b_ptr, const void* b_ptr,
void* c_ptr, void* c_ptr,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t stride_A, index_t stride_A,
index_t stride_B, index_t stride_B,
index_t stride_C) index_t stride_C)
{ {
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C}; return GemmKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
} }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
...@@ -70,7 +82,7 @@ struct GemmKernel ...@@ -70,7 +82,7 @@ struct GemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
} }
CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs) CK_TILE_HOST static bool IsSupportedArgument(const GemmKargs& kargs)
{ {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
...@@ -143,16 +155,16 @@ struct GemmKernel ...@@ -143,16 +155,16 @@ struct GemmKernel
return true; return true;
} }
CK_TILE_DEVICE auto make_gemm_tensor_views(const ADataType* a_start, CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_start, const BDataType* b_ptr,
CDataType* c_start, CDataType* c_ptr,
const GemmCommonKargs& kargs) const const GemmKargs& kargs) const
{ {
auto a_tensor_view = [&]() { auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
a_start, a_ptr,
make_tuple(kargs.M, kargs.K), make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1), make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{}, number<GemmPipeline::VectorSizeA>{},
...@@ -161,7 +173,7 @@ struct GemmKernel ...@@ -161,7 +173,7 @@ struct GemmKernel
else else
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
a_start, a_ptr,
make_tuple(kargs.M, kargs.K), make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A), make_tuple(1, kargs.stride_A),
number<1>{}, number<1>{},
...@@ -173,7 +185,7 @@ struct GemmKernel ...@@ -173,7 +185,7 @@ struct GemmKernel
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
b_start, b_ptr,
make_tuple(kargs.N, kargs.K), make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B), make_tuple(1, kargs.stride_B),
number<1>{}, number<1>{},
...@@ -182,7 +194,7 @@ struct GemmKernel ...@@ -182,7 +194,7 @@ struct GemmKernel
else else
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
b_start, b_ptr,
make_tuple(kargs.N, kargs.K), make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1), make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{}, number<GemmPipeline::VectorSizeB>{},
...@@ -194,7 +206,7 @@ struct GemmKernel ...@@ -194,7 +206,7 @@ struct GemmKernel
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
c_start, c_ptr,
make_tuple(kargs.M, kargs.N), make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1), make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{}, number<GemmPipeline::VectorSizeC>{},
...@@ -203,7 +215,7 @@ struct GemmKernel ...@@ -203,7 +215,7 @@ struct GemmKernel
else else
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
c_start, c_ptr,
make_tuple(kargs.M, kargs.N), make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C), make_tuple(1, kargs.stride_C),
number<1>{}, number<1>{},
...@@ -215,7 +227,7 @@ struct GemmKernel ...@@ -215,7 +227,7 @@ struct GemmKernel
} }
template <typename TensorView> template <typename TensorView>
CK_TILE_DEVICE auto make_gemm_pad_views(TensorView&& views) const CK_TILE_DEVICE auto MakeGemmPadViews(TensorView& views) const
{ {
auto a_pad_view = [&]() { auto a_pad_view = [&]() {
auto a_tensor_view = views.at(I0); auto a_tensor_view = views.at(I0);
...@@ -276,7 +288,7 @@ struct GemmKernel ...@@ -276,7 +288,7 @@ struct GemmKernel
template <typename PadView> template <typename PadView>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
make_gemm_tile_windows(PadView&& views, const index_t i_m, const index_t i_n) const MakeGemmTileWindows(PadView& views, const index_t i_m, const index_t i_n) const
{ {
auto a_pad_view = views.at(I0); auto a_pad_view = views.at(I0);
auto a_block_window = make_tile_window( auto a_block_window = make_tile_window(
...@@ -299,18 +311,18 @@ struct GemmKernel ...@@ -299,18 +311,18 @@ struct GemmKernel
return make_tuple(a_block_window, b_block_window, c_block_window); return make_tuple(a_block_window, b_block_window, c_block_window);
} }
CK_TILE_DEVICE void run_common_gemm_pipeline(const ADataType* a_start, CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
const BDataType* b_start, const BDataType* b_ptr,
CDataType* c_start, CDataType* c_ptr,
const GemmCommonKargs& kargs, const GemmKargs& kargs,
const index_t i_m, const index_t block_idx_m,
const index_t i_n) const const index_t block_idx_n) const
{ {
// Convert pointers to tensor views // Convert pointers to tensor views
const auto gemm_tensor_views_tuple = const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
make_gemm_tensor_views(a_start, b_start, c_start, kargs); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
const auto gemm_pad_views = make_gemm_pad_views(gemm_tensor_views_tuple); const auto& gemm_tile_windows =
const auto gemm_tile_windows = make_gemm_tile_windows(gemm_pad_views, i_m, i_n); MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
...@@ -329,15 +341,15 @@ struct GemmKernel ...@@ -329,15 +341,15 @@ struct GemmKernel
EpiloguePipeline{}(c_block_window, c_block_tile); EpiloguePipeline{}(c_block_window, c_block_tile);
} }
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const CK_TILE_DEVICE void operator()(GemmKargs kargs) const
{ {
const auto [i_m, i_n] = TilePartitioner{}(); const auto [i_m, i_n] = TilePartitioner{}();
// options // options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr); const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr); const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr); CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
run_common_gemm_pipeline(a_start, b_start, c_start, kargs, i_m, i_n); RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
} }
}; };
......
...@@ -24,7 +24,7 @@ class TestCkTileBatchedGemm : public ::testing::Test ...@@ -24,7 +24,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
using AccDataType = std::tuple_element_t<5, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs struct batched_gemm_kargs : public ck_tile::BatchedGemmHargs
{ {
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment