"vscode:/vscode.git/clone" did not exist on "2327f1a640c267743f119e59d759bc62a7887eae"
Unverified commit af664948, authored by Bartłomiej Kocot, committed by GitHub

[CK TILE] GEMM and Batched GEMM SplitK support (#1724)

* [CK TILE] Add split K support in GEMM

* Updates

* Fixes

* rebase

* fix

* Fix

* fixes

* support for batched gemm
parent 4c2eff02
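
In brief, split-K slices the GEMM K dimension across kargs.KBatch workgroup slices along blockIdx.z; every slice except the last reads a tile-aligned chunk KRead, the last slice takes the remainder, and when KBatch > 1 the partial C tiles are combined in global memory with atomic adds (the epilogue switches from set to atomic_add, so the output buffer has to start out zeroed). Below is a minimal standalone sketch of the slice arithmetic, restating the SplitKBatchOffset logic from the diff; all names and values are illustrative, not library API:

    #include <cstdio>

    // Sketch of split-K slice sizing; K1 stands in for the warp-tile K extent.
    int main()
    {
        const int K = 2048, KBatch = 3, K1 = 32;
        const int K_t   = KBatch * K1;              // split granularity
        const int KRead = (K + K_t - 1) / K_t * K1; // tile-aligned length per slice
        for(int k_id = 0; k_id < KBatch; ++k_id)
        {
            const int splitted_k = (k_id < KBatch - 1) ? KRead : K - KRead * (KBatch - 1);
            std::printf("slice %d: K offset %d, length %d\n", k_id, k_id * KRead, splitted_k);
        }
        // Prints slices (0,704), (704,704), (1408,640): together they cover K = 2048 exactly once.
    }
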
@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
-    arg_parser.insert("b", "1", "batch size")
-        .insert("m", "3840", "m dimension")
+    arg_parser.insert("m", "3840", "m dimension")
         .insert("n", "4096", "n dimension")
         .insert("k", "2048", "k dimension")
         .insert("a_layout", "R", "A tensor data layout - Row by default")
@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[])
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
-        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("split_k", "1", "splitK value");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
...
@@ -64,7 +64,7 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
     ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
-    ck_tile::index_t batch_size = arg_parser.get_int("b");
+    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");
@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
                               stride_A,
                               stride_B,
                               stride_C,
-                              batch_size,
+                              kbatch,
                               n_warmup,
                               n_repeat);
...
@@ -22,7 +22,7 @@
 #endif
 
 template <typename ALayout, typename BLayout, typename CLayout>
-float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 {
 #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
     // Memory friendly for Interwave scheduler
@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
 #endif
         ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
-    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
+    const ck_tile::index_t k_grain = args.k_batch * K_Tile;
+    const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
+    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
     const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
     const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
                                              has_hot_loop_v,
                                              tail_number_v>>;
         using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
-        auto kargs = Kernel::MakeKargs(args.p_a,
-                                       args.p_b,
-                                       args.p_c,
-                                       args.M,
-                                       args.N,
-                                       args.K,
-                                       args.stride_A,
-                                       args.stride_B,
-                                       args.stride_C);
-        const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
+        auto kargs = Kernel::MakeKernelArgs(args);
+
+        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
         constexpr dim3 blocks = Kernel::BlockSize();
 
         if(!Kernel::IsSupportedArgument(kargs))
...
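
To see what the three new sizing lines compute, take the example defaults with an assumed block-tile K_Tile = 32 and split_k = 3 (values illustrative):

    // k_grain  = 3 * 32                 = 96
    // K_split  = (2048 + 95) / 96 * 32  = 22 * 32 = 704  -> K elements per split, tile-aligned
    // num_loop = GetLoopNum(704)        = 704 / 32 = 22  -> main-loop iterations per split

Every split therefore runs the same number of pipeline iterations, which is what the hot-loop/tail-number dispatch above is derived from.
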
@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
     using CodegenGemmTraits =
         ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
     using CodegenPipelineProblem = ck_tile::
         GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
-    using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
+    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
+    using CodegenGemmPipeline =
+        ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
     auto kargs = Kernel::MakeKernelArgs(args);
 
-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
     constexpr dim3 blocks = Kernel::BlockSize();
 
+    if(!Kernel::IsSupportedArgument(kargs))
+    {
+        throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+    }
+
     if(s.log_level_ > 0)
     {
         std::cout << "Launching kernel with args:"
...
@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
-        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("split_k", "1", "splitK value");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
...
@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           ck_tile::index_t batch_stride_B,
                           ck_tile::index_t batch_stride_C,
                           ck_tile::index_t batch_count,
+                          ck_tile::index_t kbatch,
                           int n_warmup,
                           int n_repeat)
 {
@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
     args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
     args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
     args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+    args.k_batch = kbatch;
     args.M = M;
     args.N = N;
     args.K = K;
@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
     ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
     ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
     ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
+    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");
@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
                                       batch_stride_B,
                                       batch_stride_C,
                                       batch_count,
+                                      kbatch,
                                       n_warmup,
                                       n_repeat);
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -56,6 +56,13 @@ struct CShuffleEpilogue
     // No additional shared memory needed
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
 
+    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed()
+    {
+        // TODO: At now CShuffle doesn't allow to vector store after permute.
+        // It should be fixed and this function should return true.
+        return false;
+    }
+
     template <typename OAccTile>
     CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
     {
@@ -111,7 +118,9 @@ struct CShuffleEpilogue
         }
     }
 
-    template <typename ODramWindowTmp, typename OAccTile>
+    template <typename ODramWindowTmp,
+              typename OAccTile,
+              memory_operation_enum out_memory_data_op = memory_operation_enum::set>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
     {
         const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
@@ -157,14 +166,28 @@ struct CShuffleEpilogue
         // Store the tile data to the permuted location
         if constexpr(kPadM || kPadN)
         {
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
                 store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
             buffer_store_fence();
         }
         else
         {
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
                 store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
         }
     }
 };
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -35,22 +35,40 @@ struct Default2DEpilogue
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
 
+    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
+
     // TODO: this function assume store out vector size is the same as OAccTile last dimension size
     // how do we fix this ?
-    template <typename ODramWindowTmp, typename OAccTile>
+    template <typename ODramWindowTmp,
+              typename OAccTile,
+              memory_operation_enum out_memory_data_op = memory_operation_enum::set>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
     {
         // TODO: this is ugly
         if constexpr(UseRawStore && (kPadM || kPadN))
         {
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
                 store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
             buffer_store_fence();
         }
         else
         {
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
                 store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
         }
     }
 };
 } // namespace ck_tile
@@ -67,9 +67,10 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
     using KernelArgs = BatchedGemmKernelArgs;
 
-    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
+    __host__ static constexpr auto
+    GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
     {
-        return TilePartitioner::GridSize(M, N, batch_count);
+        return TilePartitioner::GridSize(M, N, KBatch * batch_count);
     }
 
     __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
@@ -85,7 +86,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
                                     hostArgs.K,
                                     hostArgs.stride_A,
                                     hostArgs.stride_B,
-                                    hostArgs.stride_C},
+                                    hostArgs.stride_C,
+                                    hostArgs.k_batch},
                                    hostArgs.batch_stride_A,
                                    hostArgs.batch_stride_B,
                                    hostArgs.batch_stride_C,
@@ -100,22 +102,38 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
     CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
-        const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
+        const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
+
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
+
         // options
         const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
         const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
-        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
+                                 splitk_batch_offset.a_k_split_offset;
 
         const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
         const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
-        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
+                                 splitk_batch_offset.b_k_split_offset;
 
         const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
        const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
         CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
 
-        this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        if(kargs.KBatch == 1)
+        {
+            this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
+        else
+        {
+            this->template RunGemm<memory_operation_enum::atomic_add>(
+                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
     }
 };
...
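
Rather than adding a fourth grid dimension, the batched kernel packs the k-slice index into blockIdx.z together with the batch index (grid.z = KBatch * batch_count, see GridSize above). A sketch of that decomposition with assumed values, illustrative only:

    // Recovering (batch, k-slice) from a flat z index.
    constexpr unsigned KBatch = 3, batch_count = 4;    // grid.z would be 12
    constexpr unsigned z       = 7;                    // one example block
    constexpr unsigned i_batch = z / KBatch;           // 2 -> third batch
    constexpr unsigned i_k     = z - i_batch * KBatch; // 1 -> second k-slice of that batch
    static_assert(i_batch * KBatch + i_k == z, "decomposition is exact");
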
@@ -93,6 +93,7 @@ struct GemmKernel
         index_t stride_A;
         index_t stride_B;
         index_t stride_C;
+        index_t KBatch;
     };
 
     CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
@@ -105,28 +106,72 @@ struct GemmKernel
                               hostArgs.K,
                               hostArgs.stride_A,
                               hostArgs.stride_B,
-                              hostArgs.stride_C};
+                              hostArgs.stride_C,
+                              hostArgs.k_batch};
     }
 
-    // CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
-    //                                                             const void* b_ptr,
-    //                                                             void* c_ptr,
-    //                                                             index_t M,
-    //                                                             index_t N,
-    //                                                             index_t K,
-    //                                                             index_t stride_A,
-    //                                                             index_t stride_B,
-    //                                                             index_t stride_C)
-    // {
-    //     return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
-    // }
-
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }
 
+    struct SplitKBatchOffset
+    {
+        __device__ SplitKBatchOffset(const GemmKernelArgs& kargs,
+                                     const std::size_t k_id = blockIdx.z)
+        {
+            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+            const index_t K_t   = kargs.KBatch * K1;
+            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
+
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                a_k_split_offset = k_id * KRead;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                a_k_split_offset = k_id * KRead * kargs.stride_A;
+            }
+
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead * kargs.stride_B;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead;
+            }
+
+            if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
+            {
+                splitted_k = KRead;
+            }
+            else
+            {
+                splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
+            }
+        }
+
+        index_t a_k_split_offset;
+        index_t b_k_split_offset;
+        index_t splitted_k;
+    };
+
     CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
     {
+        constexpr bool is_output_c_reg_transposed =
+            EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
+        if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
+                        std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
+                        is_output_c_reg_transposed) ||
+                       !(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
+        {
+            if(kargs.KBatch != 1)
+            {
+                return false;
+            }
+        }
+
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
         {
             if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
@@ -198,17 +243,19 @@ struct GemmKernel
         return true;
     }
 
-    CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
-                                            const BDataType* b_ptr,
-                                            CDataType* c_ptr,
-                                            const GemmKernelArgs& kargs) const
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
+                                                   const BDataType* b_ptr,
+                                                   CDataType* c_ptr,
+                                                   const GemmKernelArgs& kargs,
+                                                   const SplitKBatchOffset& splitk_batch_offset)
     {
         const auto& a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_ptr,
-                    make_tuple(kargs.M, kargs.K),
+                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                     make_tuple(kargs.stride_A, 1),
                     number<GemmPipeline::VectorSizeA>{},
                     number<1>{});
@@ -217,7 +264,7 @@ struct GemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_ptr,
-                    make_tuple(kargs.M, kargs.K),
+                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                     make_tuple(1, kargs.stride_A),
                     number<1>{},
                     number<1>{});
@@ -229,7 +276,7 @@ struct GemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_ptr,
-                    make_tuple(kargs.N, kargs.K),
+                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                     make_tuple(1, kargs.stride_B),
                     number<1>{},
                     number<1>{});
@@ -238,7 +285,7 @@ struct GemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_ptr,
-                    make_tuple(kargs.N, kargs.K),
+                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                     make_tuple(kargs.stride_B, 1),
                     number<GemmPipeline::VectorSizeB>{},
                     number<1>{});
@@ -248,7 +295,7 @@ struct GemmKernel
         const auto& c_tensor_view = [&]() {
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
-                return make_naive_tensor_view<address_space_enum::global>(
+                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                     c_ptr,
                     make_tuple(kargs.M, kargs.N),
                     make_tuple(kargs.stride_C, 1),
@@ -257,7 +304,7 @@ struct GemmKernel
             }
             else
             {
-                return make_naive_tensor_view<address_space_enum::global>(
+                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                     c_ptr,
                     make_tuple(kargs.M, kargs.N),
                     make_tuple(1, kargs.stride_C),
@@ -270,7 +317,7 @@ struct GemmKernel
     }
 
     template <typename TensorView>
-    CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
+    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
     {
         const auto& a_pad_view = [&]() {
             const auto& a_tensor_view = views.at(I0);
@@ -330,8 +377,8 @@ struct GemmKernel
     }
 
     template <typename PadView>
-    CK_TILE_DEVICE auto
-    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
+    CK_TILE_DEVICE static auto
+    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
     {
         const auto& a_pad_view = views.at(I0);
         const auto& a_block_window = make_tile_window(
@@ -363,23 +410,27 @@ struct GemmKernel
      * @param kargs GEMM kernel arguments
      * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
      * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     * @tparam DstInMemOp Destination memory operation (default: set).
      */
-    CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
-                                const BDataType* b_ptr,
-                                CDataType* c_ptr,
-                                const GemmKernelArgs& kargs,
-                                const index_t block_idx_m,
-                                const index_t block_idx_n) const
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+                                       const BDataType* b_ptr,
+                                       CDataType* c_ptr,
+                                       void* smem_ptr,
+                                       const GemmKernelArgs& kargs,
+                                       const SplitKBatchOffset& splitk_batch_offset,
+                                       const index_t block_idx_m,
+                                       const index_t block_idx_n)
     {
         // Create Gemm tensor views, pad views and tile windows
-        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
 
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
+        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
 
         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -389,18 +440,43 @@ struct GemmKernel
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I2);
-        EpiloguePipeline{}(c_block_window, c_block_tile);
+
+        constexpr bool is_output_c_reg_transposed =
+            EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
+        if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
+                     (GemmPipeline::VectorSizeC % 2 == 0 &&
+                      std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
+                      is_output_c_reg_transposed))
+        {
+            EpiloguePipeline{}
+                .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+                    c_block_window, c_block_tile);
+        }
     }
 
     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
+        const SplitKBatchOffset splitk_batch_offset(kargs);
+
         // options
-        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
+        const ADataType* a_ptr =
+            static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
+        const BDataType* b_ptr =
+            static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
         CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
 
-        RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        if(kargs.KBatch == 1)
+        {
+            RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
+        else
+        {
+            RunGemm<memory_operation_enum::atomic_add>(
+                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
     }
 };
...
@@ -82,6 +82,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
         return Policy::template GetSmemSize<Problem>();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <GemmPipelineScheduler Scheduler>
     struct PipelineImpl : public PipelineImplBase
     {
...
@@ -132,6 +132,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         return Policy::template GetSmemSize<Problem>();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <GemmPipelineScheduler Scheduler>
     struct PipelineImpl : public PipelineImplBase
     {
...
@@ -53,6 +53,8 @@ struct GemmPipelineAGmemBGmemCRegV1
         return Policy::template GetSmemSize<Problem>();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
...
@@ -13,6 +13,8 @@ namespace ck_tile {
 struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
 {
+    static constexpr bool TransposeC = false;
+
 #if 0
     // 2d
     template <typename Problem>
@@ -114,8 +116,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     {
         constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
         constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
-        index_t smem_size = 0;
-        smem_size += smem_size_a + smem_size_b;
+        constexpr index_t smem_size = smem_size_a + smem_size_b;
 
         return smem_size;
     }
@@ -485,10 +486,11 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         }
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
-        constexpr bool TransposeC = false;
-
         constexpr auto I0 = number<0>{};
         constexpr auto I1 = number<1>{};
         constexpr auto I2 = number<2>{};
...
@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV2
             Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
...
@@ -444,6 +444,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
         }
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
...
@@ -93,7 +93,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
         auto kargs = Kernel::MakeKernelArgs(args);
 
-        const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
+        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
         constexpr dim3 blocks = Kernel::BlockSize();
 
         if(s.log_level_ > 0)
@@ -186,6 +186,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
         args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
         args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
         args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+        args.k_batch = 1;
         args.M = M;
         args.N = N;
         args.K = K;
...
@@ -74,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
             ck_tile::
                 GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
-        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
+        const ck_tile::index_t k_grain = args.k_batch * K_Tile;
+        const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
+        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
         const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
         const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
...
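
With these changes both examples accept the new option. Assuming the usual CK tile example binaries (names and flag syntax hypothetical here, following the ArgParser keys added in the diff), a split-K run might look like:

    # hypothetical invocation; key names match the ArgParser entries above
    ./tile_example_gemm_basic   -m=3840 -n=4096 -k=2048 -split_k=2
    ./tile_example_batched_gemm -m=3840 -n=4096 -k=2048 -batch_count=8 -split_k=2

split_k=1 (the default) reproduces the previous single-pass behavior.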