"include/vscode:/vscode.git/clone" did not exist on "ecdfe960921032c1aae6dc2c4a3e0ad1b8bba559"
Commit 799cde32 authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK TILE] Refactor GemmKernel - review changes part1

parent ed528d76
......@@ -29,7 +29,7 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
struct batched_gemm_kargs : public ck_tile::BatchedGemmHargs
{
};
......
......@@ -7,17 +7,8 @@
namespace ck_tile {
struct BatchedGemmHostArgs
struct BatchedGemmHargs : GemmHargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t batch_stride_A;
index_t batch_stride_B;
index_t batch_stride_C;
......@@ -29,7 +20,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
{
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmCommonKargs = typename Base::GemmCommonKargs;
using GemmKargs = typename Base::GemmKargs;
using ADataType = typename Base::ADataType;
using BDataType = typename Base::BDataType;
......@@ -42,7 +33,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout;
struct BatchedGemmKargs : GemmCommonKargs
struct BatchedGemmKargs : GemmKargs
{
index_t batch_stride_A;
index_t batch_stride_B;
......@@ -51,7 +42,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
};
using Kargs = BatchedGemmKargs;
using Hargs = BatchedGemmHostArgs;
using Hargs = BatchedGemmHargs;
__host__ static constexpr auto GridSize(const Hargs& k)
{
......@@ -102,7 +93,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
this->run_common_gemm_pipeline(a_start, b_start, c_start, kargs, i_m, i_n);
this->RunGemm(a_start, b_start, c_start, kargs, i_m, i_n);
}
};
......
......@@ -12,6 +12,19 @@
namespace ck_tile {
struct GemmHargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
......@@ -25,7 +38,6 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
......@@ -39,7 +51,7 @@ struct GemmKernel
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmCommonKargs
struct GemmKargs
{
const void* a_ptr;
const void* b_ptr;
......@@ -52,17 +64,17 @@ struct GemmKernel
index_t stride_C;
};
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
index_t stride_C)
CK_TILE_HOST static constexpr GemmKargs MakeKargs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
index_t stride_C)
{
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
return GemmKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
......@@ -70,7 +82,7 @@ struct GemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs)
CK_TILE_HOST static bool IsSupportedArgument(const GemmKargs& kargs)
{
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
......@@ -143,16 +155,16 @@ struct GemmKernel
return true;
}
CK_TILE_DEVICE auto make_gemm_tensor_views(const ADataType* a_start,
const BDataType* b_start,
CDataType* c_start,
const GemmCommonKargs& kargs) const
CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const GemmKargs& kargs) const
{
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
a_ptr,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
......@@ -161,7 +173,7 @@ struct GemmKernel
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
a_ptr,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
......@@ -173,7 +185,7 @@ struct GemmKernel
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
b_ptr,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<1>{},
......@@ -182,7 +194,7 @@ struct GemmKernel
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
b_ptr,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
......@@ -194,7 +206,7 @@ struct GemmKernel
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
......@@ -203,7 +215,7 @@ struct GemmKernel
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
......@@ -215,7 +227,7 @@ struct GemmKernel
}
template <typename TensorView>
CK_TILE_DEVICE auto make_gemm_pad_views(TensorView&& views) const
CK_TILE_DEVICE auto MakeGemmPadViews(TensorView& views) const
{
auto a_pad_view = [&]() {
auto a_tensor_view = views.at(I0);
......@@ -276,7 +288,7 @@ struct GemmKernel
template <typename PadView>
CK_TILE_DEVICE auto
make_gemm_tile_windows(PadView&& views, const index_t i_m, const index_t i_n) const
MakeGemmTileWindows(PadView& views, const index_t i_m, const index_t i_n) const
{
auto a_pad_view = views.at(I0);
auto a_block_window = make_tile_window(
......@@ -299,18 +311,18 @@ struct GemmKernel
return make_tuple(a_block_window, b_block_window, c_block_window);
}
CK_TILE_DEVICE void run_common_gemm_pipeline(const ADataType* a_start,
const BDataType* b_start,
CDataType* c_start,
const GemmCommonKargs& kargs,
const index_t i_m,
const index_t i_n) const
CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const GemmKargs& kargs,
const index_t block_idx_m,
const index_t block_idx_n) const
{
// Convert pointers to tensor views
const auto gemm_tensor_views_tuple =
make_gemm_tensor_views(a_start, b_start, c_start, kargs);
const auto gemm_pad_views = make_gemm_pad_views(gemm_tensor_views_tuple);
const auto gemm_tile_windows = make_gemm_tile_windows(gemm_pad_views, i_m, i_n);
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
const auto& gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
......@@ -329,15 +341,15 @@ struct GemmKernel
EpiloguePipeline{}(c_block_window, c_block_tile);
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
CK_TILE_DEVICE void operator()(GemmKargs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
// options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
run_common_gemm_pipeline(a_start, b_start, c_start, kargs, i_m, i_n);
RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
}
};
......
......@@ -24,7 +24,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
struct batched_gemm_kargs : public ck_tile::BatchedGemmHargs
{
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment