Commit 5ab76075 authored by Aleksander Dudek
Browse files

Batched gemm - passed batch args

parent 533204d6
...@@ -14,6 +14,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -14,6 +14,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t stride_B, ck_tile::index_t stride_B,
ck_tile::index_t stride_C, ck_tile::index_t stride_C,
ck_tile::index_t kbatch, ck_tile::index_t kbatch,
ck_tile::index_t batch_stride_A,
ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
...@@ -28,6 +32,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -28,6 +32,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_A = stride_A; args.stride_A = stride_A;
args.stride_B = stride_B; args.stride_B = stride_B;
args.stride_C = stride_C; args.stride_C = stride_C;
args.batch_stride_A = batch_stride_A;
args.batch_stride_B = batch_stride_B;
args.batch_stride_C = batch_stride_C;
args.batch_count = batch_count;
float ave_time = gemm_calc<ALayout, BLayout, CLayout>( float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
...@@ -63,6 +71,18 @@ int run_batched_gemm_example(int argc, char* argv[]) ...@@ -63,6 +71,18 @@ int run_batched_gemm_example(int argc, char* argv[])
ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t batch_size = arg_parser.get_int("b"); ck_tile::index_t batch_size = arg_parser.get_int("b");
ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
std::cout << "Received args: " << std::endl;
std::cout << "batch_stride_A: " << batch_stride_A << '\n'
<< "batch_stride_B: " << batch_stride_B << '\n'
<< "batch_stride_C: " << batch_stride_C << '\n'
<< "batch_count: " << batch_count << std::endl;
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
...@@ -137,6 +157,10 @@ int run_batched_gemm_example(int argc, char* argv[]) ...@@ -137,6 +157,10 @@ int run_batched_gemm_example(int argc, char* argv[])
stride_B, stride_B,
stride_C, stride_C,
batch_size, batch_size,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
n_warmup, n_warmup,
n_repeat); n_repeat);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment