Commit 71eea17c authored by Aleksander Dudek

Batched gemm - counting strides

parent 5ab76075
@@ -96,7 +96,7 @@ float gemm_calc(const batched_gemm_basic_args& args, const ck_tile::stream_config& s)
         args.batch_stride_C,
         args.batch_count);
-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
     constexpr dim3 blocks = Kernel::BlockSize();
     if(s.log_level_ > 0)
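Previously the grid's z dimension was sized with `kbatch` (a split-K factor); after this change every batch in the batched GEMM gets its own z-slice. A minimal illustration of the assumed mapping follows; the tile sizes and the flattening of M/N tiles into grid.x are placeholders, not ck_tile's actual `Kernel::GridSize` implementation:

// Illustrative only: how a batched launch is assumed to shape the grid.
// MPerBlock/NPerBlock are placeholder tile sizes.
#include <hip/hip_runtime.h>

dim3 batched_grid(int M, int N, int batch_count)
{
    constexpr int MPerBlock = 128;
    constexpr int NPerBlock = 128;
    const int m_tiles = (M + MPerBlock - 1) / MPerBlock;
    const int n_tiles = (N + NPerBlock - 1) / NPerBlock;
    // x enumerates the M x N output tiles, z enumerates the batch.
    return dim3(m_tiles * n_tiles, 1, batch_count);
}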
@@ -73,16 +73,16 @@ auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("b", "1", "batch size")
-        .insert("m", "3840", "m dimension")
-        .insert("n", "4096", "n dimension")
-        .insert("k", "4096", "k dimension")
-        .insert("stride_a", "0", "Tensor A stride")
-        .insert("stride_b", "0", "Tensor B stride")
-        .insert("stride_c", "0", "Tensor C stride")
-        .insert("batch_stride_a", "0", "Batch A stride")
-        .insert("batch_stride_b", "0", "Batch B stride")
-        .insert("batch_stride_c", "0", "Batch C stride")
-        .insert("batch_count", "1", "Batch count")
+        .insert("m", "256", "m dimension")
+        .insert("n", "128", "n dimension")
+        .insert("k", "128", "k dimension")
+        .insert("stride_a", "128", "Tensor A stride")
+        .insert("stride_b", "128", "Tensor B stride")
+        .insert("stride_c", "128", "Tensor C stride")
+        .insert("batch_stride_a", "32768", "Batch A stride")
+        .insert("batch_stride_b", "16384", "Batch B stride")
+        .insert("batch_stride_c", "32768", "Batch C stride")
+        .insert("batch_count", "16", "Batch count")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
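For context, the new defaults describe a densely packed batch: each batch stride is exactly the element count of one matrix, and each leading stride matches the fastest-varying dimension. A quick sanity check, assuming the example's usual layouts (row-major A and C, column-major B):

// Sanity check of the new packed-batch defaults.
// Layout assumption: row-major A (M x K), column-major B (K x N), row-major C (M x N).
#include <cassert>

int main()
{
    const long M = 256, N = 128, K = 128;
    // Leading strides: K for A and B, N for C -- all 128 here.
    assert(K == 128 && N == 128);
    // Batch strides: one packed matrix per batch.
    assert(M * K == 32768); // batch_stride_a
    assert(N * K == 16384); // batch_stride_b
    assert(M * N == 32768); // batch_stride_c
    return 0;
}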
@@ -89,9 +89,12 @@ struct BatchedGemmKernel
     CK_TILE_DEVICE void operator()(BatchedGemmCommonKargs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
-        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
+        // const auto i_k = blockIdx.z;
+        // options
+        const ADataType* a_start = static_cast<const ADataType*>(
+            kargs.a_ptr); //+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A);
+        const BDataType* b_start = static_cast<const BDataType*>(
+            kargs.b_ptr); //+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
         // Convert pointers to tensor views
         auto a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -169,7 +172,8 @@ struct BatchedGemmKernel
         auto c_block_tile =
             GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
-        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
+        CDataType* c_start = static_cast<CDataType*>(
+            kargs.c_ptr); // + __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_C);
         auto c_tensor_view = [&]() {
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
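The commented-out fragments show where this commit is headed: read the batch index from blockIdx.z and advance each base pointer by that batch's stride. A hedged sketch of the finished per-batch offsets, as a fragment of the kernel body reusing the names from the comments above (blockIdx.z is uniform across the block, so __builtin_amdgcn_readfirstlane legally scalarizes the offset into an SGPR):

// Sketch only: the offsets the commented-out code points toward.
const auto i_k = blockIdx.z; // one batch per grid z-slice

const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr) +
                           __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr) +
                           __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr) +
                     __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_C);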