Commit 84213e27 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

modified for correctness pr#881

parent 1c03a65d
# Instructions for ```example_gemm_xdl```
# Instructions for ```example_gemv_splitk```
## Run ```example_gemm_xdl```
## Run ```example_gemv_splitk```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: number of splitK batches
./bin/example_gemm_xdl 0 1 5 151
#arg4: number of splitk batches
./bin/example_gemv_splitk 0 1 5 151
```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
Result (MI250 @ 800Mhz, 181.05TFlops peak FP16)
```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
arg.c_grid_desc_m_n_{ 3840, 4096}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
a_m_k: dim 2, lengths {1, 4608}, strides {4608, 1}
b_k_n: dim 2, lengths {4608, 1104}, strides {1, 4608}
c_m_n: dim 2, lengths {1, 1104}, strides {1104, 1}
arg.a_grid_desc_kbatch_k0_m_k1_{1,4, 1, 8}
arg.b_grid_desc_kbatch_k0_n_k1_{1,4, 1104, 8}
arg.c_grid_desc_m_n_{ 1, 1104}
launch_and_time_kernel: grid_dim {1359, 1, 1}, block_dim {64, 1, 1}
Warm up
Start running 5 times...
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
Start running 10 times...
Perf: 0.0191358 ms, 0.531698 TFlops,532.295 GB/s
```
......@@ -55,25 +55,28 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
{
// use default case
}
else if(argc == 4)
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 10)
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideC = std::stoi(argv[9]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
......
......@@ -103,12 +103,13 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
return true;
}
c_m_n_device_buf.Zero();
c_m_n_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
if(config.do_verification)
{
auto ref_gemv = ReferenceGemmInstance{};
auto ref_invoker = ref_gemv.MakeInvoker();
......@@ -124,11 +125,9 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
......@@ -146,13 +145,16 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemv.GetTypeString() << std::endl;
return true;
#ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
bool run_gemv_example(int argc, char* argv[])
{
ProblemSize problem_size;
// problem_size.M = 1;
ExecutionConfig config;
if(argc == 1)
{
......@@ -185,7 +187,7 @@ bool run_gemv_example(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: KBatch\n");
printf("arg4: splitk\n");
printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment