Commit b85e1128 authored by Aleksander Dudek

[CK_TILE] Refactor GemmKernel - constness fixes

parent c1f51cd4
@@ -79,15 +79,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
-    auto kargs = Kernel::MakeKernelArgs(args.a_ptr,
-                                        args.b_ptr,
-                                        args.c_ptr,
-                                        args.M,
-                                        args.N,
-                                        args.K,
-                                        args.stride_A,
-                                        args.stride_B,
-                                        args.stride_C);
+    auto kargs = Kernel::MakeKernelArgs(args);

     const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
     constexpr dim3 blocks = Kernel::BlockSize();
...
@@ -75,4 +75,4 @@ auto create_args(int argc, char* argv[])
 }

 // host API
-float gemm_calc(ck_tile::GemmHostArgs args, const ck_tile::stream_config& s);
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
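With the refactor, the example's host code passes the whole `GemmHostArgs` struct instead of unpacking nine fields at the call site. A minimal sketch of the new calling convention inside `gemm_calc` (the buffer pointers and sizes are hypothetical placeholders; `Kernel` is the alias defined in the hunk above):

```cpp
// Sketch, not part of the commit: p_a, p_b, p_c are assumed device buffers.
// Parameter order follows the GemmHostArgs base-constructor call shown below.
ck_tile::GemmHostArgs host_args(p_a, p_b, p_c,
                                /*k_batch=*/1,
                                /*M=*/3840, /*N=*/4096, /*K=*/2048,
                                /*stride_A=*/2048, /*stride_B=*/2048, /*stride_C=*/4096);

auto kargs = Kernel::MakeKernelArgs(host_args); // one struct in, one struct out
const dim3 grids = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
```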
@@ -79,19 +79,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config&
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
-    auto kargs = Kernel::MakeKernelArgs(args.a_ptr,
-                                        args.b_ptr,
-                                        args.c_ptr,
-                                        args.M,
-                                        args.N,
-                                        args.K,
-                                        args.stride_A,
-                                        args.stride_B,
-                                        args.stride_C,
-                                        args.batch_stride_A,
-                                        args.batch_stride_B,
-                                        args.batch_stride_C,
-                                        args.batch_count);
+    auto kargs = Kernel::MakeKernelArgs(args);

     const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
     constexpr dim3 blocks = Kernel::BlockSize();
...
@@ -56,4 +56,4 @@ auto create_args(int argc, char* argv[])
 }

 // host API
-float batched_gemm(ck_tile::BatchedGemmHostArgs args, const ck_tile::stream_config& s);
+float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s);
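The batched entry point follows the same convention: callers fill a `BatchedGemmHostArgs` (its constructor is shown in the next file) and pass it by const reference. A hedged usage sketch with made-up extents:

```cpp
// Sketch, not part of the commit: pointers and extents are placeholders,
// and a default-constructed stream_config is assumed to be valid here.
ck_tile::BatchedGemmHostArgs host_args(p_a, p_b, p_c,
                                       /*k_batch=*/1,
                                       /*M=*/1024, /*N=*/1024, /*K=*/512,
                                       /*stride_A=*/512, /*stride_B=*/512, /*stride_C=*/1024,
                                       /*batch_stride_A=*/1024 * 512,
                                       /*batch_stride_B=*/512 * 1024,
                                       /*batch_stride_C=*/1024 * 1024,
                                       /*batch_count=*/8);

float avg_time = batched_gemm(host_args, ck_tile::stream_config{});
```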
@@ -7,6 +7,38 @@
 namespace ck_tile {

+struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
+{
+    CK_TILE_HOST BatchedGemmHostArgs() = default;
+    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
+                                     const void* b_ptr_,
+                                     void* c_ptr_,
+                                     ck_tile::index_t k_batch_,
+                                     ck_tile::index_t M_,
+                                     ck_tile::index_t N_,
+                                     ck_tile::index_t K_,
+                                     ck_tile::index_t stride_A_,
+                                     ck_tile::index_t stride_B_,
+                                     ck_tile::index_t stride_C_,
+                                     ck_tile::index_t batch_stride_A_,
+                                     ck_tile::index_t batch_stride_B_,
+                                     ck_tile::index_t batch_stride_C_,
+                                     ck_tile::index_t batch_count_)
+        : GemmHostArgs(
+              a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
+          batch_stride_A(batch_stride_A_),
+          batch_stride_B(batch_stride_B_),
+          batch_stride_C(batch_stride_C_),
+          batch_count(batch_count_)
+    {
+    }
+
+    ck_tile::index_t batch_stride_A;
+    ck_tile::index_t batch_stride_B;
+    ck_tile::index_t batch_stride_C;
+    ck_tile::index_t batch_count;
+};
+
 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
 {
@@ -42,25 +74,22 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
     __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }

-    CK_TILE_HOST static constexpr BatchedGemmKernelArgs MakeKernelArgs(const void* a_ptr,
-                                                                       const void* b_ptr,
-                                                                       void* c_ptr,
-                                                                       index_t M,
-                                                                       index_t N,
-                                                                       index_t K,
-                                                                       index_t stride_A,
-                                                                       index_t stride_B,
-                                                                       index_t stride_C,
-                                                                       index_t batch_stride_A,
-                                                                       index_t batch_stride_B,
-                                                                       index_t batch_stride_C,
-                                                                       index_t batch_count)
+    CK_TILE_HOST static constexpr BatchedGemmKernelArgs
+    MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
     {
-        return BatchedGemmKernelArgs{{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C},
-                                     batch_stride_A,
-                                     batch_stride_B,
-                                     batch_stride_C,
-                                     batch_count};
+        return BatchedGemmKernelArgs{{hostArgs.a_ptr,
+                                      hostArgs.b_ptr,
+                                      hostArgs.c_ptr,
+                                      hostArgs.M,
+                                      hostArgs.N,
+                                      hostArgs.K,
+                                      hostArgs.stride_A,
+                                      hostArgs.stride_B,
+                                      hostArgs.stride_C},
+                                     hostArgs.batch_stride_A,
+                                     hostArgs.batch_stride_B,
+                                     hostArgs.batch_stride_C,
+                                     hostArgs.batch_count};
     }

     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
...
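Worth noting in the hunk above: `MakeKernelArgs` builds the derived kernel-args struct with nested braces, where the inner list initializes the `GemmKernelArgs` part and the trailing values fill the batch-specific fields. A reduced, runnable illustration of that initialization pattern (the struct definitions here are simplified stand-ins, not the real ck_tile types):

```cpp
#include <cstdint>
#include <iostream>

using index_t = std::int32_t;

// Simplified stand-ins for the real kernel-args structs (assumed shapes).
struct GemmKernelArgs
{
    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t M, N, K, stride_A, stride_B, stride_C;
};

struct BatchedGemmKernelArgs : GemmKernelArgs
{
    index_t batch_stride_A, batch_stride_B, batch_stride_C, batch_count;
};

int main()
{
    float a = 0, b = 0, c = 0; // stand-ins for device buffers
    // Inner braces -> GemmKernelArgs part, trailing values -> batch fields.
    BatchedGemmKernelArgs kargs{{&a, &b, &c, 16, 16, 16, 16, 16, 16}, 256, 256, 256, 4};
    std::cout << kargs.M << " " << kargs.batch_count << '\n'; // prints: 16 4
}
```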
@@ -56,38 +56,6 @@ struct GemmHostArgs : public GemmProblem
     index_t k_batch;
 };

-struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
-{
-    CK_TILE_HOST BatchedGemmHostArgs() = default;
-    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
-                                     const void* b_ptr_,
-                                     void* c_ptr_,
-                                     ck_tile::index_t k_batch_,
-                                     ck_tile::index_t M_,
-                                     ck_tile::index_t N_,
-                                     ck_tile::index_t K_,
-                                     ck_tile::index_t stride_A_,
-                                     ck_tile::index_t stride_B_,
-                                     ck_tile::index_t stride_C_,
-                                     ck_tile::index_t batch_stride_A_,
-                                     ck_tile::index_t batch_stride_B_,
-                                     ck_tile::index_t batch_stride_C_,
-                                     ck_tile::index_t batch_count_)
-        : GemmHostArgs(
-              a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
-          batch_stride_A(batch_stride_A_),
-          batch_stride_B(batch_stride_B_),
-          batch_stride_C(batch_stride_C_),
-          batch_count(batch_count_)
-    {
-    }
-
-    ck_tile::index_t batch_stride_A;
-    ck_tile::index_t batch_stride_B;
-    ck_tile::index_t batch_stride_C;
-    ck_tile::index_t batch_count;
-};
-
 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 struct GemmKernel
 {
@@ -127,18 +95,18 @@ struct GemmKernel
         index_t stride_C;
     };

-    CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
-                                                                const void* b_ptr,
-                                                                void* c_ptr,
-                                                                index_t M,
-                                                                index_t N,
-                                                                index_t K,
-                                                                index_t stride_A,
-                                                                index_t stride_B,
-                                                                index_t stride_C)
+    CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
     {
-        return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
+        return GemmKernelArgs{hostArgs.a_ptr,
+                              hostArgs.b_ptr,
+                              hostArgs.c_ptr,
+                              hostArgs.M,
+                              hostArgs.N,
+                              hostArgs.K,
+                              hostArgs.stride_A,
+                              hostArgs.stride_B,
+                              hostArgs.stride_C};
     }

     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
@@ -365,8 +345,8 @@ struct GemmKernel
             make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
             {i_n, 0});

         const auto& c_pad_view = views.at(I2);
-        const auto& c_block_window = make_tile_window(
+        auto c_block_window = make_tile_window(
             c_pad_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
             {i_m, i_n});
@@ -394,8 +374,7 @@ struct GemmKernel
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-        const auto& gemm_tile_windows =
-            MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

         // allocate LDS
         __shared__ char smem_ptr[GetSmemSize()];
@@ -405,11 +384,11 @@ struct GemmKernel
         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
         const auto& b_block_window = gemm_tile_windows.at(I1);
-        auto c_block_tile =
+        const auto& c_block_tile =
             GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);

         // Run Epilogue Pipeline
-        auto c_block_window = gemm_tile_windows.at(I2);
+        auto& c_block_window = gemm_tile_windows.at(I2);
         EpiloguePipeline{}(c_block_window, c_block_tile);
     }
...
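The last three hunks are the constness fixes the commit title refers to: the result tile returned by the pipeline is bound as `const auto&` (a const reference bound to a temporary extends its lifetime to the enclosing scope), while the output window is taken as a mutable `auto&` from the windows tuple because the epilogue writes through it. A small self-contained illustration of the two binding rules, using toy types rather than ck_tile ones:

```cpp
#include <iostream>
#include <tuple>

struct Tile
{
    int value = 0;
};

Tile compute_tile() { return Tile{42}; } // returns a temporary, like the pipeline call

int main()
{
    // Mirrors `const auto& c_block_tile = GemmPipeline{}(...)`:
    // the temporary's lifetime is extended to match the reference.
    const auto& tile = compute_tile();

    std::tuple<Tile> windows{Tile{}};
    // Mirrors `auto& c_block_window = gemm_tile_windows.at(I2)`:
    // a mutable reference into the tuple, written through by the epilogue.
    auto& window = std::get<0>(windows);
    window.value = tile.value;

    std::cout << std::get<0>(windows).value << '\n'; // prints: 42
}
```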