Commit 71eea17c authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

Batched gemm - counting strides

parent 5ab76075
...@@ -96,7 +96,7 @@ float gemm_calc(const batched_gemm_basic_args& args, const ck_tile::stream_confi ...@@ -96,7 +96,7 @@ float gemm_calc(const batched_gemm_basic_args& args, const ck_tile::stream_confi
args.batch_stride_C, args.batch_stride_C,
args.batch_count); args.batch_count);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0) if(s.log_level_ > 0)
......
...@@ -73,16 +73,16 @@ auto create_args(int argc, char* argv[]) ...@@ -73,16 +73,16 @@ auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size") arg_parser.insert("b", "1", "batch size")
.insert("m", "3840", "m dimension") .insert("m", "256", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "128", "n dimension")
.insert("k", "4096", "k dimension") .insert("k", "128", "k dimension")
.insert("stride_a", "0", "Tensor A stride") .insert("stride_a", "128", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride") .insert("stride_b", "128", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride") .insert("stride_c", "128", "Tensor C stride")
.insert("batch_stride_a", "0", "Batch A stride") .insert("batch_stride_a", "32768", "Batch A stride")
.insert("batch_stride_b", "0", "Batch B stride") .insert("batch_stride_b", "16384", "Batch B stride")
.insert("batch_stride_c", "0", "Batch C stride") .insert("batch_stride_c", "32768", "Batch C stride")
.insert("batch_count", "1", "Batch count") .insert("batch_count", "16", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
......
...@@ -89,9 +89,12 @@ struct BatchedGemmKernel ...@@ -89,9 +89,12 @@ struct BatchedGemmKernel
CK_TILE_DEVICE void operator()(BatchedGemmCommonKargs kargs) const CK_TILE_DEVICE void operator()(BatchedGemmCommonKargs kargs) const
{ {
const auto [i_m, i_n] = TilePartitioner{}(); const auto [i_m, i_n] = TilePartitioner{}();
// options // const auto i_k = blockIdx.z;
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr); // options
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr); const ADataType* a_start = static_cast<const ADataType*>(
kargs.a_ptr); //+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A);
const BDataType* b_start = static_cast<const BDataType*>(
kargs.b_ptr); //+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
// Convert pointers to tensor views // Convert pointers to tensor views
auto a_tensor_view = [&]() { auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
...@@ -169,7 +172,8 @@ struct BatchedGemmKernel ...@@ -169,7 +172,8 @@ struct BatchedGemmKernel
auto c_block_tile = auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr); CDataType* c_start = static_cast<CDataType*>(
kargs.c_ptr); //; + __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_C);
auto c_tensor_view = [&]() { auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment