Commit ce87bcc7 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

tmp

parent c8a8385f
...@@ -250,7 +250,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification, ...@@ -250,7 +250,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
float best_ave_time = std::numeric_limits<float>::max(); float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
int num_kernel = 0; int num_kernel = 0;
// profile device operation instances // profile device operation instances
...@@ -316,7 +316,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification, ...@@ -316,7 +316,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
{ {
h_device_buf.FromDevice(h_m_n.mData.data()); h_device_buf.FromDevice(h_m_n.mData.data());
pass = pass && ck::utils::check_err( validator.check_err(
h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2); h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
} }
} }
...@@ -327,19 +327,14 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification, ...@@ -327,19 +327,14 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
} }
} }
if(num_kernel == 0) if(num_kernel != 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
pass = false;
}
else
{ {
if(time_kernel) if(time_kernel)
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl; << best_op_name << std::endl;
} }
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -281,6 +281,8 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -281,6 +281,8 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
ck::utils::CorrectnessValidator validator;
// profile device GEMM instances // profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs) for(auto& gemm_ptr : gemm_ptrs)
{ {
...@@ -343,9 +345,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -343,9 +345,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); validator.check_err(c_m_n_device_result, c_m_n_host_result);
ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); validator.check_err(reduce0_m_device_result, reduce0_m_host_result);
ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); validator.check_err(reduce1_m_device_result, reduce1_m_host_result);
if(do_log) if(do_log)
{ {
...@@ -376,6 +378,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -376,6 +378,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
} }
} }
validator.is_success();
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
} }
......
...@@ -158,7 +158,7 @@ bool profile_gemm_bilinear_impl(int do_verification, ...@@ -158,7 +158,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
// profile device operation instances // profile device operation instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
...@@ -215,7 +215,7 @@ bool profile_gemm_bilinear_impl(int do_verification, ...@@ -215,7 +215,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
{ {
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); validator.check_err(e_m_n_device_result, e_m_n_host_result);
} }
} }
else else
...@@ -227,7 +227,7 @@ bool profile_gemm_bilinear_impl(int do_verification, ...@@ -227,7 +227,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -147,7 +147,7 @@ bool profile_gemm_fastgelu_impl(int do_verification, ...@@ -147,7 +147,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
// profile device operation instances // profile device operation instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
...@@ -203,7 +203,7 @@ bool profile_gemm_fastgelu_impl(int do_verification, ...@@ -203,7 +203,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
{ {
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); validator.check_err(e_m_n_device_result, e_m_n_host_result);
} }
} }
else else
...@@ -215,7 +215,7 @@ bool profile_gemm_fastgelu_impl(int do_verification, ...@@ -215,7 +215,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -42,7 +42,7 @@ int profile_gemm_impl(int do_verification, ...@@ -42,7 +42,7 @@ int profile_gemm_impl(int do_verification,
int StrideB, int StrideB,
int StrideC) int StrideC)
{ {
bool pass = true; ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...@@ -188,7 +188,7 @@ int profile_gemm_impl(int do_verification, ...@@ -188,7 +188,7 @@ int profile_gemm_impl(int do_verification,
{ {
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); validator.check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log) if(do_log)
{ {
...@@ -247,7 +247,7 @@ int profile_gemm_impl(int do_verification, ...@@ -247,7 +247,7 @@ int profile_gemm_impl(int do_verification,
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl; << best_op_name << std::endl;
return pass ? 0 : 1; return !validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -250,6 +250,8 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -250,6 +250,8 @@ bool profile_gemm_reduce_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
ck::utils::CorrectnessValidator validator;
// profile device GEMM instances // profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs) for(auto& gemm_ptr : gemm_ptrs)
{ {
...@@ -310,9 +312,10 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -310,9 +312,10 @@ bool profile_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); validator.check_err(c_m_n_device_result, c_m_n_host_result);
ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); validator.check_err(reduce0_m_device_result, reduce0_m_host_result);
ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); validator.check_err(reduce1_m_device_result, reduce1_m_host_result);
validator.is_success();
if(do_log) if(do_log)
{ {
...@@ -346,7 +349,7 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -346,7 +349,7 @@ bool profile_gemm_reduce_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -43,7 +43,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -43,7 +43,7 @@ bool profile_gemm_splitk_impl(int do_verification,
int StrideC, int StrideC,
int KBatch) int KBatch)
{ {
bool pass = true; ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...@@ -181,7 +181,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -181,7 +181,7 @@ bool profile_gemm_splitk_impl(int do_verification,
{ {
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); validator.check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log) if(do_log)
{ {
...@@ -221,12 +221,12 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -221,12 +221,12 @@ bool profile_gemm_splitk_impl(int do_verification,
std::string msg = "Error: Incorrect results!"; std::string msg = "Error: Incorrect results!";
double rtol = 1e-1; double rtol = 1e-1;
double atol = 1e-1; double atol = 1e-1;
pass = pass & ck::utils::check_err( validator.check_err(
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
} }
else else
{ {
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); validator.check_err(c_m_n_device_result, c_m_n_host_result);
} }
if(tflops > best_tflops) if(tflops > best_tflops)
...@@ -286,7 +286,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -286,7 +286,7 @@ bool profile_gemm_splitk_impl(int do_verification,
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl; << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -43,7 +43,7 @@ bool profile_gemm_streamk_impl(int do_verification, ...@@ -43,7 +43,7 @@ bool profile_gemm_streamk_impl(int do_verification,
int StrideC, int StrideC,
uint32_t NumSKBlocks = 0xffffffff) uint32_t NumSKBlocks = 0xffffffff)
{ {
bool pass = true; ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...@@ -176,7 +176,7 @@ bool profile_gemm_streamk_impl(int do_verification, ...@@ -176,7 +176,7 @@ bool profile_gemm_streamk_impl(int do_verification,
{ {
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); validator.check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log) if(do_log)
{ {
...@@ -260,7 +260,7 @@ bool profile_gemm_streamk_impl(int do_verification, ...@@ -260,7 +260,7 @@ bool profile_gemm_streamk_impl(int do_verification,
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl; << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -122,7 +122,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, ...@@ -122,7 +122,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
// profile device op instances // profile device op instances
bool pass = true; ck::utils::CorrectnessValidator validator;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
...@@ -159,7 +159,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, ...@@ -159,7 +159,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
{ {
in_device_buf.FromDevice(in_device.mData.data()); in_device_buf.FromDevice(in_device.mData.data());
pass = pass & ck::utils::check_err(in_device, in_host); validator.check_err(in_device, in_host);
if(do_log) if(do_log)
{ {
...@@ -250,7 +250,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, ...@@ -250,7 +250,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -160,6 +160,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -160,6 +160,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.input_left_pads_, begin(input_left_pads)); range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads)); range_copy(conv_param.input_right_pads_, begin(input_right_pads));
ck::utils::CorrectnessValidator validator;
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = auto argument_ptr =
...@@ -214,15 +215,13 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -214,15 +215,13 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
{ {
wei_device_buf.FromDevice(weight_device_result.mData.data()); wei_device_buf.FromDevice(weight_device_result.mData.data());
bool pass = ck::utils::check_err(weight_device_result, weight_host_result); validator.check_err(weight_device_result, weight_host_result);
if(!pass) if(!validator.is_success())
{ {
std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl; std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
} }
all_pass &= pass;
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "output : ", output.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "output : ", output.mData, ",") << std::endl;
...@@ -250,7 +249,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -250,7 +249,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return all_pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -142,7 +142,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification, ...@@ -142,7 +142,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
// profile device op instances // profile device op instances
bool pass = true; ck::utils::CorrectnessValidator validator;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
...@@ -179,7 +179,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification, ...@@ -179,7 +179,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
{ {
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
pass = pass & ck::utils::check_err(device_output, host_output); validator.check_err(device_output, host_output);
if(do_log) if(do_log)
{ {
...@@ -246,7 +246,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification, ...@@ -246,7 +246,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -39,7 +39,7 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification, ...@@ -39,7 +39,7 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
const std::vector<int>& StrideCs) const std::vector<int>& StrideCs)
{ {
bool pass = true; ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...@@ -238,8 +238,7 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification, ...@@ -238,8 +238,7 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
bool group_pass = bool group_pass =
ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result); validator.check_err(c_m_n_device_results[i], c_m_n_host_result);
pass = pass && group_pass;
std::cout << "group: " << i << " verification result: " << std::boolalpha std::cout << "group: " << i << " verification result: " << std::boolalpha
<< group_pass << std::endl; << group_pass << std::endl;
...@@ -267,13 +266,13 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification, ...@@ -267,13 +266,13 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << std::endl; std::cout << "Verification: " << (validator.is_success() ? "SUCCESS" : "FAILURE") << std::endl;
} }
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -44,7 +44,7 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -44,7 +44,7 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs,
int kbatch = 1) int kbatch = 1)
{ {
bool pass = true; ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...@@ -274,7 +274,7 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -274,7 +274,7 @@ bool profile_grouped_gemm_impl(int do_verification,
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1) if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
{ {
instance_pass = instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i], instance_pass && validator.check_err(c_m_n_device_results[i],
c_m_n_host_results[i], c_m_n_host_results[i],
"Error: Incorrect results!", "Error: Incorrect results!",
0.06); 0.06);
...@@ -282,7 +282,7 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -282,7 +282,7 @@ bool profile_grouped_gemm_impl(int do_verification,
else else
{ {
instance_pass = instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i], instance_pass && validator.check_err(c_m_n_device_results[i],
c_m_n_host_results[i]); c_m_n_host_results[i]);
} }
...@@ -303,8 +303,6 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -303,8 +303,6 @@ bool profile_grouped_gemm_impl(int do_verification,
std::cout << "Instance: " << gemm_name << " verification " std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl; << (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
} }
float ave_time = float ave_time =
...@@ -354,7 +352,7 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -354,7 +352,7 @@ bool profile_grouped_gemm_impl(int do_verification,
<< std::endl; << std::endl;
} }
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -110,7 +110,7 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -110,7 +110,7 @@ bool profile_groupnorm_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
int num_kernel = 0; ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
...@@ -169,7 +169,7 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -169,7 +169,7 @@ bool profile_groupnorm_impl(int do_verification,
{ {
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); bool pass = validator.check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log) if(do_log)
{ {
...@@ -182,7 +182,6 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -182,7 +182,6 @@ bool profile_groupnorm_impl(int do_verification,
{ {
std::cout << inst_ptr->GetTypeString() << " failed verification: "; std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
} }
else else
{ {
...@@ -198,14 +197,7 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -198,14 +197,7 @@ bool profile_groupnorm_impl(int do_verification,
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_instance_name << std::endl; << best_instance_name << std::endl;
} }
return validator.is_success();
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
} }
} // namespace profiler } // namespace profiler
......
...@@ -121,7 +121,7 @@ bool profile_layernorm_impl(int do_verification, ...@@ -121,7 +121,7 @@ bool profile_layernorm_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
int num_kernel = 0; ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
...@@ -186,7 +186,7 @@ bool profile_layernorm_impl(int do_verification, ...@@ -186,7 +186,7 @@ bool profile_layernorm_impl(int do_verification,
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); validator.check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log) if(do_log)
{ {
...@@ -199,7 +199,6 @@ bool profile_layernorm_impl(int do_verification, ...@@ -199,7 +199,6 @@ bool profile_layernorm_impl(int do_verification,
{ {
std::cout << inst_ptr->GetTypeString() << " failed verification: "; std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
} }
else else
{ {
...@@ -218,13 +217,7 @@ bool profile_layernorm_impl(int do_verification, ...@@ -218,13 +217,7 @@ bool profile_layernorm_impl(int do_verification,
<< best_instance_name << std::endl; << best_instance_name << std::endl;
} }
if(num_kernel == 0) return validator.is_success();
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
} }
} // namespace profiler } // namespace profiler
......
...@@ -150,7 +150,7 @@ bool profile_pool3d_fwd_impl(int do_verification, ...@@ -150,7 +150,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
int num_kernel = 0; ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
...@@ -213,7 +213,7 @@ bool profile_pool3d_fwd_impl(int do_verification, ...@@ -213,7 +213,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
{ {
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData, bool pass = validator.check_err(out_n_c_do_ho_wo_device.mData,
out_n_c_do_ho_wo_host.mData, out_n_c_do_ho_wo_host.mData,
"Error: Incorrect results", "Error: Incorrect results",
1e-3, 1e-3,
...@@ -223,7 +223,7 @@ bool profile_pool3d_fwd_impl(int do_verification, ...@@ -223,7 +223,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
{ {
out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data()); out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_indices_n_c_do_ho_wo_device, pass = pass && validator.check_err(out_indices_n_c_do_ho_wo_device,
out_indices_n_c_do_ho_wo_host); out_indices_n_c_do_ho_wo_host);
} }
...@@ -250,7 +250,6 @@ bool profile_pool3d_fwd_impl(int do_verification, ...@@ -250,7 +250,6 @@ bool profile_pool3d_fwd_impl(int do_verification,
{ {
std::cout << inst_ptr->GetTypeString() << " failed verification: "; std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", in_length, ", ") << "]." << std::endl; LogRange(std::cout << "lengths = [", in_length, ", ") << "]." << std::endl;
return false;
} }
else else
{ {
...@@ -267,13 +266,7 @@ bool profile_pool3d_fwd_impl(int do_verification, ...@@ -267,13 +266,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
<< best_instance_name << std::endl; << best_instance_name << std::endl;
} }
if(num_kernel == 0) return validator.is_success();
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
} }
} // namespace profiler } // namespace profiler
......
...@@ -196,7 +196,7 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -196,7 +196,7 @@ bool profile_reduce_impl_impl(bool do_verification,
invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6); invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6);
int num_kernel = 0; int num_kernel = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
if constexpr(!invalid_reduce) if constexpr(!invalid_reduce)
{ {
...@@ -403,12 +403,12 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -403,12 +403,12 @@ bool profile_reduce_impl_impl(bool do_verification,
bool single_pass; bool single_pass;
out_dev.FromDevice(out.mData.data()); out_dev.FromDevice(out.mData.data());
single_pass = ck::utils::check_err(out, out_ref); single_pass = validator.check_err(out, out_ref);
if(OutputIndex) if(OutputIndex)
{ {
out_indices_dev.FromDevice(out_indices.mData.data()); out_indices_dev.FromDevice(out_indices.mData.data());
single_pass = single_pass && ck::utils::check_err(out_indices, out_indices_ref); single_pass = single_pass && validator.check_err(out_indices, out_indices_ref);
}; };
if(!single_pass) if(!single_pass)
...@@ -416,7 +416,6 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -416,7 +416,6 @@ bool profile_reduce_impl_impl(bool do_verification,
std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl; std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl;
} }
pass = pass && single_pass;
}; };
if(do_dumpout) if(do_dumpout)
...@@ -447,13 +446,7 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -447,13 +446,7 @@ bool profile_reduce_impl_impl(bool do_verification,
"The requested reduction operation is not supported, please check!"); "The requested reduction operation is not supported, please check!");
}; };
if(num_kernel == 0) return validator.is_success();
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
};
return pass;
}; };
template <typename InDataType, typename AccDataType, typename OutDataType> template <typename InDataType, typename AccDataType, typename OutDataType>
......
...@@ -123,7 +123,7 @@ bool profile_softmax_impl(int do_verification, ...@@ -123,7 +123,7 @@ bool profile_softmax_impl(int do_verification,
std::string best_instance_name; std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max(); float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
std::vector<bool> instance_pass; ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instances) for(auto& inst_ptr : instances)
{ {
...@@ -176,7 +176,7 @@ bool profile_softmax_impl(int do_verification, ...@@ -176,7 +176,7 @@ bool profile_softmax_impl(int do_verification,
bool pass = true; bool pass = true;
if(std::is_same<InDataType, int8_t>::value) if(std::is_same<InDataType, int8_t>::value)
{ {
pass = pass && ck::utils::check_err( pass = pass && validator.check_err(
out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1); out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1);
if(do_log) if(do_log)
{ {
...@@ -188,7 +188,7 @@ bool profile_softmax_impl(int do_verification, ...@@ -188,7 +188,7 @@ bool profile_softmax_impl(int do_verification,
} }
else else
{ {
pass = pass && ck::utils::check_err(out.mData, out_ref.mData); pass = pass && validator.check_err(out.mData, out_ref.mData);
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
...@@ -219,8 +219,7 @@ bool profile_softmax_impl(int do_verification, ...@@ -219,8 +219,7 @@ bool profile_softmax_impl(int do_verification,
<< "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec
<< " GB/s, " << best_instance_name << std::endl; << " GB/s, " << best_instance_name << std::endl;
} }
return std::all_of( return validator.is_success();
std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; });
} }
} // namespace profiler } // namespace profiler
......
...@@ -46,110 +46,113 @@ class TestConvUtil : public ::testing::Test ...@@ -46,110 +46,113 @@ class TestConvUtil : public ::testing::Test
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D) TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
{ {
ck::utils::CorrectnessValidator validator;
// stride 2, dilation 1, pad 1 // stride 2, dilation 1, pad 1
SetNDParams(1, 2, 1, 1); SetNDParams(1, 2, 1, 1);
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths(); std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err( EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D.")); out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D."));
// stride 1, dilation 1, pad 1 // stride 1, dilation 1, pad 1
SetNDParams(1, 1, 1, 1); SetNDParams(1, 1, 1, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err( EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{71}, "Error: ConvParams 1D stride {1}.")); out_spatial_len, std::vector<ck::index_t>{71}, "Error: ConvParams 1D stride {1}."));
// stride 2, dilation 1, pad 2 // stride 2, dilation 1, pad 2
SetNDParams(1, 2, 1, 2); SetNDParams(1, 2, 1, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len, EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{37}, std::vector<ck::index_t>{37},
"Error: ConvParams 1D padding left/right {2}.")); "Error: ConvParams 1D padding left/right {2}."));
// stride 2, dilation 2, pad 2 // stride 2, dilation 2, pad 2
SetNDParams(1, 2, 2, 2); SetNDParams(1, 2, 2, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err( EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D dilation {2}.")); out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D dilation {2}."));
// stride 3, dilation 2, pad 1 // stride 3, dilation 2, pad 1
SetNDParams(1, 3, 2, 1); SetNDParams(1, 3, 2, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE( EXPECT_TRUE(
ck::utils::check_err(out_spatial_len, validator.check_err(out_spatial_len,
std::vector<ck::index_t>{23}, std::vector<ck::index_t>{23},
"Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.")); "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."));
} }
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D) TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
{ {
ck::utils::CorrectnessValidator validator;
// stride 2, dilation 1, pad 1 // stride 2, dilation 1, pad 1
SetNDParams(2, 2, 1, 1); SetNDParams(2, 2, 1, 1);
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths(); std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len, EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{36, 36}, std::vector<ck::index_t>{36, 36},
"Error: ConvParams 2D default constructor.")); "Error: ConvParams 2D default constructor."));
// stride 1, dilation 1, pad 1 // stride 1, dilation 1, pad 1
SetNDParams(2, 1, 1, 1); SetNDParams(2, 1, 1, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err( EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}.")); out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}."));
// stride 2, dilation 1, pad 2 // stride 2, dilation 1, pad 2
SetNDParams(2, 2, 1, 2); SetNDParams(2, 2, 1, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len, EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{37, 37}, std::vector<ck::index_t>{37, 37},
"Error: ConvParams 2D padding left/right {2,2}.")); "Error: ConvParams 2D padding left/right {2,2}."));
// stride 2, dilation 2, pad 2 // stride 2, dilation 2, pad 2
SetNDParams(2, 2, 2, 2); SetNDParams(2, 2, 2, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err( EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}.")); out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}."));
// stride 3, dilation 2, pad 1 // stride 3, dilation 2, pad 1
SetNDParams(2, 3, 2, 1); SetNDParams(2, 3, 2, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE( EXPECT_TRUE(
ck::utils::check_err(out_spatial_len, validator.check_err(out_spatial_len,
std::vector<ck::index_t>{23, 23}, std::vector<ck::index_t>{23, 23},
"Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.")); "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."));
} }
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
{ {
ck::utils::CorrectnessValidator validator;
// stride 2, dilation 1, pad 1 // stride 2, dilation 1, pad 1
SetNDParams(3, 2, 1, 1); SetNDParams(3, 2, 1, 1);
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths(); std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err( EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{36, 36, 36}, "Error: ConvParams 3D.")); out_spatial_len, std::vector<ck::index_t>{36, 36, 36}, "Error: ConvParams 3D."));
// stride 1, dilation 1, pad 1 // stride 1, dilation 1, pad 1
SetNDParams(3, 1, 1, 1); SetNDParams(3, 1, 1, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len, EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{71, 71, 71}, std::vector<ck::index_t>{71, 71, 71},
"Error: ConvParams 3D stride {1, 1, 1}.")); "Error: ConvParams 3D stride {1, 1, 1}."));
// stride 2, dilation 1, pad 2 // stride 2, dilation 1, pad 2
SetNDParams(3, 2, 1, 2); SetNDParams(3, 2, 1, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len, EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{37, 37, 37}, std::vector<ck::index_t>{37, 37, 37},
"Error: ConvParams 3D padding left/right {2, 2, 2}.")); "Error: ConvParams 3D padding left/right {2, 2, 2}."));
// stride 2, dilation 2, pad 2 // stride 2, dilation 2, pad 2
SetNDParams(3, 2, 2, 2); SetNDParams(3, 2, 2, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len, EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{36, 36, 36}, std::vector<ck::index_t>{36, 36, 36},
"Error: ConvParams 3D dilation {2, 2, 2}.")); "Error: ConvParams 3D dilation {2, 2, 2}."));
// stride 3, dilation 2, pad 1 // stride 3, dilation 2, pad 1
SetNDParams(3, 3, 2, 1); SetNDParams(3, 3, 2, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths(); out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err( EXPECT_TRUE(validator.check_err(
out_spatial_len, out_spatial_len,
std::vector<ck::index_t>{23, 23, 23}, std::vector<ck::index_t>{23, 23, 23},
"Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}.")); "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."));
......
...@@ -227,30 +227,30 @@ struct TestGemm ...@@ -227,30 +227,30 @@ struct TestGemm
if(is_supported && do_verification) if(is_supported && do_verification)
{ {
// Assert // Assert
bool res = false; ck::utils::CorrectnessValidator validator;
if(std::is_same<CDataType, float>::value) if(std::is_same<CDataType, float>::value)
{ {
res = ck::utils::check_err(c_device, c_host); res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, ck::half_t>::value) else if(std::is_same<CDataType, ck::half_t>::value)
{ {
res = ck::utils::check_err(c_device, c_host); res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, ck::bhalf_t>::value) else if(std::is_same<CDataType, ck::bhalf_t>::value)
{ {
res = ck::utils::check_err(c_device, c_host); res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, int8_t>::value) else if(std::is_same<CDataType, int8_t>::value)
{ {
res = ck::utils::check_err(c_device, c_host); res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, double>::value) else if(std::is_same<CDataType, double>::value)
{ {
res = ck::utils::check_err(c_device, c_host); res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment