Commit 84213e27 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

modified for correctness pr#881

parent 1c03a65d
# Instructions for ```example_gemm_xdl``` # Instructions for ```example_gemv_splitk```
## Run ```example_gemm_xdl``` ## Run ```example_gemv_splitk```
```bash ```bash
#arg1: verification (0=no, 1=yes) #arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1) #arg3: run kernel # of times (>1)
#arg4: number of splitK batches #arg4: number of splitk batches
./bin/example_gemm_xdl 0 1 5 151 ./bin/example_gemv_splitk 0 1 5 151
``` ```
Result (MI100 @ 1087MHz, 133.5TFlops peak FP16) Result (MI250 @ 800MHz, 181.05TFlops peak FP16)
``` ```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} a_m_k: dim 2, lengths {1, 4608}, strides {4608, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} b_k_n: dim 2, lengths {4608, 1104}, strides {1, 4608}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} c_m_n: dim 2, lengths {1, 1104}, strides {1104, 1}
arg.a_grid_desc_k0_m_k1_{512, 3840, 8} arg.a_grid_desc_kbatch_k0_m_k1_{1,4, 1, 8}
arg.b_grid_desc_k0_n_k1_{512, 4096, 8} arg.b_grid_desc_kbatch_k0_n_k1_{1,4, 1104, 8}
arg.c_grid_desc_m_n_{ 3840, 4096} arg.c_grid_desc_m_n_{ 1, 1104}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} launch_and_time_kernel: grid_dim {1359, 1, 1}, block_dim {64, 1, 1}
Warm up Warm up
Start running 5 times... Start running 10 times...
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s Perf: 0.0191358 ms, 0.531698 TFlops, 532.295 GB/s
``` ```
...@@ -55,25 +55,28 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi ...@@ -55,25 +55,28 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
{ {
// use default case // use default case
} }
else if(argc == 4) else if(argc == 5)
{ {
config.do_verification = std::stoi(argv[1]); config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]); config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]); config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
} }
else if(argc == 10) else if(argc == 11)
{ {
config.do_verification = std::stoi(argv[1]); config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]); config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]); config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[4]); problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[5]); problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[6]); problem_size.K = std::stoi(argv[7]);
problem_size.StrideA = std::stoi(argv[7]); problem_size.stride_A = std::stoi(argv[8]);
problem_size.StrideB = std::stoi(argv[8]); problem_size.stride_B = std::stoi(argv[9]);
problem_size.StrideC = std::stoi(argv[9]); problem_size.stride_C = std::stoi(argv[10]);
} }
else else
{ {
......
...@@ -103,12 +103,13 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -103,12 +103,13 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
return true; return true;
} }
c_m_n_device_buf.Zero(); c_m_n_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
if(config.do_verification) if(config.do_verification)
{ {
auto ref_gemv = ReferenceGemmInstance{}; auto ref_gemv = ReferenceGemmInstance{};
auto ref_invoker = ref_gemv.MakeInvoker(); auto ref_invoker = ref_gemv.MakeInvoker();
...@@ -124,11 +125,9 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -124,11 +125,9 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>(); c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else #else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif #endif
} }
...@@ -146,13 +145,16 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -146,13 +145,16 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemv.GetTypeString() << std::endl; << gemv.GetTypeString() << std::endl;
return true; #ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
} }
bool run_gemv_example(int argc, char* argv[]) bool run_gemv_example(int argc, char* argv[])
{ {
ProblemSize problem_size; ProblemSize problem_size;
// problem_size.M = 1;
ExecutionConfig config; ExecutionConfig config;
if(argc == 1) if(argc == 1)
{ {
...@@ -185,7 +187,7 @@ bool run_gemv_example(int argc, char* argv[]) ...@@ -185,7 +187,7 @@ bool run_gemv_example(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: KBatch\n"); printf("arg4: splitk\n");
printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment