Commit 8bba35f2 authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK_TILE] Refactor GemmKernel - review changes

parent b9806269
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template <typename DataType> template <typename DataType>
struct GemmBasicTypeConfig; struct GemmBasicTypeConfig;
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "batched_gemm.hpp" #include "batched_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const BatchedGemmHostArgs& args, const ck_tile::stream_config& s) float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{ {
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false; constexpr bool kPadM = false;
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template <typename DataType> template <typename DataType>
struct BatchedGemmTypeConfig; struct BatchedGemmTypeConfig;
...@@ -57,4 +56,4 @@ auto create_args(int argc, char* argv[]) ...@@ -57,4 +56,4 @@ auto create_args(int argc, char* argv[])
} }
// host API // host API
float batched_gemm(BatchedGemmHostArgs args, const ck_tile::stream_config& s); float batched_gemm(ck_tile::BatchedGemmHostArgs args, const ck_tile::stream_config& s);
...@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
BatchedGemmHostArgs args; ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
......
...@@ -12,6 +12,82 @@ ...@@ -12,6 +12,82 @@
namespace ck_tile { namespace ck_tile {
struct GemmProblem
{
CK_TILE_HOST GemmProblem() = default;
CK_TILE_HOST GemmProblem(
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
{
}
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
struct GemmHostArgs : public GemmProblem
{
CK_TILE_HOST GemmHostArgs() = default;
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_)
: GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t k_batch;
};
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: GemmHostArgs(
a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
batch_stride_C(batch_stride_C_),
batch_count(batch_count_)
{
}
ck_tile::index_t batch_stride_A;
ck_tile::index_t batch_stride_B;
ck_tile::index_t batch_stride_C;
ck_tile::index_t batch_count;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_> template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel struct GemmKernel
{ {
...@@ -147,7 +223,7 @@ struct GemmKernel ...@@ -147,7 +223,7 @@ struct GemmKernel
CDataType* c_ptr, CDataType* c_ptr,
const GemmKernelArgs& kargs) const const GemmKernelArgs& kargs) const
{ {
auto&& a_tensor_view = [&]() { const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
...@@ -168,7 +244,7 @@ struct GemmKernel ...@@ -168,7 +244,7 @@ struct GemmKernel
} }
}(); }();
auto&& b_tensor_view = [&]() { const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
...@@ -189,7 +265,7 @@ struct GemmKernel ...@@ -189,7 +265,7 @@ struct GemmKernel
} }
}(); }();
auto&& c_tensor_view = [&]() { const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
...@@ -214,10 +290,10 @@ struct GemmKernel ...@@ -214,10 +290,10 @@ struct GemmKernel
} }
template <typename TensorView> template <typename TensorView>
CK_TILE_DEVICE auto MakeGemmPadViews(TensorView&& views) const CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
{ {
auto&& a_pad_view = [&]() { const auto& a_pad_view = [&]() {
auto&& a_tensor_view = views.at(I0); const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -234,8 +310,8 @@ struct GemmKernel ...@@ -234,8 +310,8 @@ struct GemmKernel
} }
}(); }();
auto&& b_pad_view = [&]() { const auto& b_pad_view = [&]() {
auto&& b_tensor_view = views.at(I1); const auto& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -252,8 +328,8 @@ struct GemmKernel ...@@ -252,8 +328,8 @@ struct GemmKernel
} }
}(); }();
auto&& c_pad_view = [&]() { const auto& c_pad_view = [&]() {
auto&& c_tensor_view = views.at(I2); const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -275,22 +351,22 @@ struct GemmKernel ...@@ -275,22 +351,22 @@ struct GemmKernel
template <typename PadView> template <typename PadView>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
MakeGemmTileWindows(PadView&& views, const index_t i_m, const index_t i_n) const MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
{ {
auto&& a_pad_view = views.at(I0); const auto& a_pad_view = views.at(I0);
auto&& a_block_window = make_tile_window( const auto& a_block_window = make_tile_window(
a_pad_view, a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0}); {i_m, 0});
auto&& b_pad_view = views.at(I1); const auto& b_pad_view = views.at(I1);
auto&& b_block_window = make_tile_window( const auto& b_block_window = make_tile_window(
b_pad_view, b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0}); {i_n, 0});
auto&& c_pad_view = views.at(I2); const auto& c_pad_view = views.at(I2);
auto&& c_block_window = make_tile_window( const auto& c_block_window = make_tile_window(
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
...@@ -299,15 +375,14 @@ struct GemmKernel ...@@ -299,15 +375,14 @@ struct GemmKernel
} }
/** /**
* Create tensor views, pad views, tile windows. * @brief Runs single GEMM problem cooperatively by whole workgroup.
* Runs GEMM cooperatively by whole workgroup with CShuffle or Default 2D Epilogue
* *
* @param a_ptr input A pointer * @param a_ptr input A pointer
* @param b_ptr input B pointer * @param b_ptr input B pointer
* @param c_ptr output C pointer * @param c_ptr output C pointer
* @param kargs GEMM kernel arguments * @param kargs GEMM kernel arguments
* @param block_idx_m M block index * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n N block index * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*/ */
CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr, CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr, const BDataType* b_ptr,
...@@ -317,9 +392,10 @@ struct GemmKernel ...@@ -317,9 +392,10 @@ struct GemmKernel
const index_t block_idx_n) const const index_t block_idx_n) const
{ {
// Create Gemm tensor views, pad views and tile windows // Create Gemm tensor views, pad views and tile windows
auto&& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs); const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
auto&& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto&& gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); const auto& gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
...@@ -327,13 +403,13 @@ struct GemmKernel ...@@ -327,13 +403,13 @@ struct GemmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole workgroup. // Run GEMM cooperatively by whole workgroup.
auto&& a_block_window = gemm_tile_windows.at(I0); const auto& a_block_window = gemm_tile_windows.at(I0);
auto&& b_block_window = gemm_tile_windows.at(I1); const auto& b_block_window = gemm_tile_windows.at(I1);
auto&& c_block_tile = auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
// Run CShuffle or Default 2D Epilogue // Run Epilogue Pipeline
auto&& c_block_window = gemm_tile_windows.at(I2); auto c_block_window = gemm_tile_windows.at(I2);
EpiloguePipeline{}(c_block_window, c_block_tile); EpiloguePipeline{}(c_block_window, c_block_tile);
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
namespace ck_tile {
struct Problem
{
CK_TILE_HOST Problem() = default;
CK_TILE_HOST Problem(
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
{
}
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
struct GemmHostArgs : public Problem
{
CK_TILE_HOST GemmHostArgs() = default;
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_)
: Problem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t k_batch;
};
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: GemmHostArgs(
a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
batch_stride_C(batch_stride_C_),
batch_count(batch_count_)
{
}
ck_tile::index_t batch_stride_A;
ck_tile::index_t batch_stride_B;
ck_tile::index_t batch_stride_C;
ck_tile::index_t batch_count;
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment