Commit 5ab76075 authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

Batched gemm - passed batch args

parent 533204d6
...@@ -14,20 +14,28 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -14,20 +14,28 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t stride_B, ck_tile::index_t stride_B,
ck_tile::index_t stride_C, ck_tile::index_t stride_C,
ck_tile::index_t kbatch, ck_tile::index_t kbatch,
ck_tile::index_t batch_stride_A,
ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
batched_gemm_basic_args args; batched_gemm_basic_args args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
args.kbatch = kbatch; args.kbatch = kbatch;
args.M = M; args.M = M;
args.N = N; args.N = N;
args.K = K; args.K = K;
args.stride_A = stride_A; args.stride_A = stride_A;
args.stride_B = stride_B; args.stride_B = stride_B;
args.stride_C = stride_C; args.stride_C = stride_C;
args.batch_stride_A = batch_stride_A;
args.batch_stride_B = batch_stride_B;
args.batch_stride_C = batch_stride_C;
args.batch_count = batch_count;
float ave_time = gemm_calc<ALayout, BLayout, CLayout>( float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
...@@ -63,8 +71,20 @@ int run_batched_gemm_example(int argc, char* argv[]) ...@@ -63,8 +71,20 @@ int run_batched_gemm_example(int argc, char* argv[])
ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t batch_size = arg_parser.get_int("b"); ck_tile::index_t batch_size = arg_parser.get_int("b");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
std::cout << "Received args: " << std::endl;
std::cout << "batch_stride_A: " << batch_stride_A << '\n'
<< "batch_stride_B: " << batch_stride_B << '\n'
<< "batch_stride_C: " << batch_stride_C << '\n'
<< "batch_count: " << batch_count << std::endl;
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
...@@ -137,6 +157,10 @@ int run_batched_gemm_example(int argc, char* argv[]) ...@@ -137,6 +157,10 @@ int run_batched_gemm_example(int argc, char* argv[])
stride_B, stride_B,
stride_C, stride_C,
batch_size, batch_size,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
n_warmup, n_warmup,
n_repeat); n_repeat);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment