Commit ce87bcc7 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

tmp

parent c8a8385f
...@@ -137,8 +137,8 @@ bool run_conv_bwd_data(const ExecutionConfig& config, ...@@ -137,8 +137,8 @@ bool run_conv_bwd_data(const ExecutionConfig& config,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
in_device_buf.FromDevice(in_device.mData.data()); in_device_buf.FromDevice(in_device.mData.data());
ck::utils::CorrectnessValidator validator;
return ck::utils::check_err(in_device.mData, in_host.mData); return validator.check_err(in_device.mData, in_host.mData);
} }
return true; return true;
......
...@@ -64,12 +64,13 @@ bool run_permute_bundle(const Problem& problem) ...@@ -64,12 +64,13 @@ bool run_permute_bundle(const Problem& problem)
{ {
return false; return false;
} }
ck::utils::CorrectnessValidator validator;
return ck::utils::check_err(output_bundle_tensor.AsSpan<const DataType>(), validator.check_err(output_bundle_tensor.AsSpan<const DataType>(),
output_tensor.AsSpan<const DataType>(), output_tensor.AsSpan<const DataType>(),
"Error: incorrect results in output tensor", "Error: incorrect results in output tensor",
1e-6, 1e-6,
1e-6); 1e-6);
return validator.is_success();
} }
bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes) bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
......
...@@ -51,12 +51,13 @@ bool run_permute_element(const Problem& problem) ...@@ -51,12 +51,13 @@ bool run_permute_element(const Problem& problem)
{ {
return false; return false;
} }
ck::utils::CorrectnessValidator validator;
return ck::utils::check_err(output_tensor.AsSpan<const OutDataType>(), validator.check_err(output_tensor.AsSpan<const OutDataType>(),
output_tensor_host.AsSpan<const OutDataType>(), output_tensor_host.AsSpan<const OutDataType>(),
"Error: incorrect results in output tensor", "Error: incorrect results in output tensor",
1e-6, 1e-6,
1e-6); 1e-6);
return validator.is_success();
} }
bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes) bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes)
......
...@@ -160,11 +160,12 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -160,11 +160,12 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
pass &= ck::utils::CorrectnessValidator validator;
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
validator.check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
} }
return (pass ? 0 : 1); return !validator.is_success();
} }
int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op) int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op)
......
...@@ -148,11 +148,11 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -148,11 +148,11 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
pass &= ck::utils::CorrectnessValidator validator;
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); validator.check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
} }
return (pass ? 0 : 1); return !validator.is_success();
} }
int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op) int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op)
......
...@@ -150,11 +150,11 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -150,11 +150,11 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
pass &= ck::utils::CorrectnessValidator validator;
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); validator.check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
} }
return (pass ? 0 : 1); return !validator.is_success();
} }
int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op) int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op)
......
...@@ -132,11 +132,11 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -132,11 +132,11 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
pass &= ck::utils::CorrectnessValidator validator;
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); validator.check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
} }
return (pass ? 0 : 1); return !validator.is_success();
} }
int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op) int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op)
......
...@@ -256,8 +256,10 @@ bool run_grouped_conv_conv_fwd(bool do_verification, ...@@ -256,8 +256,10 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
out1_device_buf.FromDevice(out1_device.mData.data()); out1_device_buf.FromDevice(out1_device.mData.data());
#endif #endif
return ck::utils::check_err( ck::utils::CorrectnessValidator validator;
validator.check_err(
out1_device, out1_host, "Error: incorrect results!", 1e-5f, 1e-4f); out1_device, out1_host, "Error: incorrect results!", 1e-5f, 1e-4f);
return validator.is_success();
} }
return true; return true;
......
...@@ -89,7 +89,7 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -89,7 +89,7 @@ int run_groupnorm_example(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< device_instance.GetTypeString() << std::endl; << device_instance.GetTypeString() << std::endl;
bool pass = true; ck::utils::CorrectnessValidator validator;
{ {
Tensor<YDataType> host_y({N, H, W, G, C}); Tensor<YDataType> host_y({N, H, W, G, C});
using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm<XDataType, using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm<XDataType,
...@@ -106,8 +106,8 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -106,8 +106,8 @@ int run_groupnorm_example(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); validator.check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
} }
return (pass ? 0 : 1); return !validator.is_success();
} }
...@@ -398,9 +398,9 @@ int main(int argc, char* argv[]) ...@@ -398,9 +398,9 @@ int main(int argc, char* argv[])
cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx)); cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx));
}); });
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) ck::utils::CorrectnessValidator validator;
? 0 validator.check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData);
: 1; return !validator.is_success();
} }
return 0; return 0;
......
...@@ -398,9 +398,9 @@ int main(int argc, char* argv[]) ...@@ -398,9 +398,9 @@ int main(int argc, char* argv[])
cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx)); cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx));
}); });
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) ck::utils::CorrectnessValidator validator;
? 0 validator.check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData);
: 1; return validator.is_success();
} }
return 0; return 0;
......
...@@ -100,7 +100,7 @@ int main() ...@@ -100,7 +100,7 @@ int main()
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
bool pass = true; ck::utils::CorrectnessValidator validator;
if(do_verification) if(do_verification)
{ {
...@@ -108,9 +108,8 @@ int main() ...@@ -108,9 +108,8 @@ int main()
Tensor<BDataType> host_b(nhwc); Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{}); host_elementwise4D(host_b, a, PassThrough{});
pass &= validator.check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -110,7 +110,7 @@ int main() ...@@ -110,7 +110,7 @@ int main()
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
bool pass = true; ck::utils::CorrectnessValidator validator;
if(do_verification) if(do_verification)
{ {
...@@ -122,9 +122,8 @@ int main() ...@@ -122,9 +122,8 @@ int main()
host_b, a, nchw, PassThrough{}); host_b, a, nchw, PassThrough{});
// LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
pass &= validator.check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -156,7 +156,7 @@ int main() ...@@ -156,7 +156,7 @@ int main()
std::cout << "Bandwidth is : " << bandwidth << "GB/s . " << std::endl; std::cout << "Bandwidth is : " << bandwidth << "GB/s . " << std::endl;
std::cout << "Time elapase is : " << ela_time << " ms . " << std::endl; std::cout << "Time elapase is : " << ela_time << " ms . " << std::endl;
bool pass = true; ck::utils::CorrectnessValidator validator;
{ {
std::vector<std::size_t> mn = {static_cast<unsigned long>(M), std::vector<std::size_t> mn = {static_cast<unsigned long>(M),
static_cast<unsigned long>(N)}; static_cast<unsigned long>(N)};
...@@ -184,12 +184,11 @@ int main() ...@@ -184,12 +184,11 @@ int main()
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass &= validator.check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3); if(!validator.is_success())
if(!(pass))
{ {
std::cout << "layernorm wrong" << std::endl; std::cout << "layernorm wrong" << std::endl;
} }
} }
return (pass ? 0 : 1); return !validator.is_success();
} }
...@@ -123,7 +123,9 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi ...@@ -123,7 +123,9 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); ck::utils::CorrectnessValidator validator;
validator.check_err(e_m_n_device_result, e_m_n_host_result);
return validator.is_success();
} }
return true; return true;
......
...@@ -396,13 +396,13 @@ int main(int argc, char* argv[]) ...@@ -396,13 +396,13 @@ int main(int argc, char* argv[])
double rtol = 1e-3; double rtol = 1e-3;
double atol = 1e-3; double atol = 1e-3;
return ck::utils::check_err(c_gs_ms_os_device_result.mData, ck::utils::CorrectnessValidator validator;
validator.check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData, c_gs_ms_os_host_result.mData,
"Error: Incorrect results!", "Error: Incorrect results!",
rtol, rtol,
atol) atol);
? 0 return !validator.is_success();
: 1;
} }
return 0; return 0;
......
...@@ -183,16 +183,18 @@ bool pool3d_test(bool do_verification, ...@@ -183,16 +183,18 @@ bool pool3d_test(bool do_verification,
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_n_c_do_ho_wo_device, out_n_c_do_ho_wo_host); ck::utils::CorrectnessValidator validator;
validator.check_err(out_n_c_do_ho_wo_device, out_n_c_do_ho_wo_host);
if constexpr(OutputIndex) if constexpr(OutputIndex)
{ {
out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data()); out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_indices_n_c_do_ho_wo_device, ck::utils::CorrectnessValidator validator;
validator.check_err(out_indices_n_c_do_ho_wo_device,
out_indices_n_c_do_ho_wo_host); out_indices_n_c_do_ho_wo_host);
}; };
} }
return (pass); return validator.is_success();
}; };
...@@ -174,7 +174,7 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -174,7 +174,7 @@ bool maxpool_bwd_test(bool do_verification,
std::cout << "Pool fwd perf: " << ave_time_fwd << " ms" << std::endl; std::cout << "Pool fwd perf: " << ave_time_fwd << " ms" << std::endl;
std::cout << "Pool bwd perf: " << ave_time_bwd << " ms" << std::endl; std::cout << "Pool bwd perf: " << ave_time_bwd << " ms" << std::endl;
bool pass = true; ck::utils::CorrectnessValidator validator;
if(do_verification) if(do_verification)
{ {
...@@ -219,10 +219,10 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -219,10 +219,10 @@ bool maxpool_bwd_test(bool do_verification,
indices_device_buf.FromDevice(indices_n_c_ho_wo_device.mData.data()); indices_device_buf.FromDevice(indices_n_c_ho_wo_device.mData.data());
din_device_buf.FromDevice(din_n_c_hi_wi_device.mData.data()); din_device_buf.FromDevice(din_n_c_hi_wi_device.mData.data());
pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host); validator.check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host);
pass = pass && ck::utils::check_err(indices_n_c_ho_wo_device, indices_n_c_ho_wo_host); validator.check_err(indices_n_c_ho_wo_device, indices_n_c_ho_wo_host);
pass = pass && ck::utils::check_err(din_n_c_hi_wi_device, din_n_c_hi_wi_host); validator.check_err(din_n_c_hi_wi_device, din_n_c_hi_wi_host);
} }
return (pass); return validator.is_success();
}; };
...@@ -69,7 +69,7 @@ int main() ...@@ -69,7 +69,7 @@ int main()
std::cout << "perf: " << ave_time << " ms" << std::endl; std::cout << "perf: " << ave_time << " ms" << std::endl;
bool pass = true; ck::utils::CorrectnessValidator validator;
if(do_verification) if(do_verification)
{ {
Tensor<YDataType> y_host(HostTensorDescriptor{N}); Tensor<YDataType> y_host(HostTensorDescriptor{N});
...@@ -81,8 +81,8 @@ int main() ...@@ -81,8 +81,8 @@ int main()
} }
y_device_buf.FromDevice(y.mData.data()); y_device_buf.FromDevice(y.mData.data());
pass = ck::utils::check_err(y, y_host); validator.check_err(y, y_host);
} }
return (pass ? 0 : 1); return !validator.is_success();
} }
...@@ -120,7 +120,7 @@ bool pool3d_bwd_test(bool do_verification, ...@@ -120,7 +120,7 @@ bool pool3d_bwd_test(bool do_verification,
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::cout << "Perf: " << ave_time << std::endl; std::cout << "Perf: " << ave_time << std::endl;
bool pass = true; ck::utils::CorrectnessValidator validator;
if(do_verification) if(do_verification)
{ {
...@@ -140,8 +140,8 @@ bool pool3d_bwd_test(bool do_verification, ...@@ -140,8 +140,8 @@ bool pool3d_bwd_test(bool do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
din_device_buf.FromDevice(din_dev.mData.data()); din_device_buf.FromDevice(din_dev.mData.data());
pass = ck::utils::check_err(din_dev, din_host); validator.check_err(din_dev, din_host);
} }
return pass; return validator.is_success();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment