Commit ce87bcc7 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

tmp

parent c8a8385f
......@@ -137,8 +137,8 @@ bool run_conv_bwd_data(const ExecutionConfig& config,
ref_invoker.Run(ref_argument);
in_device_buf.FromDevice(in_device.mData.data());
return ck::utils::check_err(in_device.mData, in_host.mData);
ck::utils::CorrectnessValidator validator;
return validator.check_err(in_device.mData, in_host.mData);
}
return true;
......
......@@ -64,12 +64,13 @@ bool run_permute_bundle(const Problem& problem)
{
return false;
}
return ck::utils::check_err(output_bundle_tensor.AsSpan<const DataType>(),
ck::utils::CorrectnessValidator validator;
validator.check_err(output_bundle_tensor.AsSpan<const DataType>(),
output_tensor.AsSpan<const DataType>(),
"Error: incorrect results in output tensor",
1e-6,
1e-6);
return validator.is_success();
}
bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
......
......@@ -51,12 +51,13 @@ bool run_permute_element(const Problem& problem)
{
return false;
}
return ck::utils::check_err(output_tensor.AsSpan<const OutDataType>(),
ck::utils::CorrectnessValidator validator;
validator.check_err(output_tensor.AsSpan<const OutDataType>(),
output_tensor_host.AsSpan<const OutDataType>(),
"Error: incorrect results in output tensor",
1e-6,
1e-6);
return validator.is_success();
}
bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes)
......
......@@ -160,11 +160,12 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data());
pass &=
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
ck::utils::CorrectnessValidator validator;
validator.check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return (pass ? 0 : 1);
return !validator.is_success();
}
int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op)
......
......@@ -148,11 +148,11 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data());
pass &=
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
ck::utils::CorrectnessValidator validator;
validator.check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return (pass ? 0 : 1);
return !validator.is_success();
}
int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op)
......
......@@ -150,11 +150,11 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data());
pass &=
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
ck::utils::CorrectnessValidator validator;
validator.check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return (pass ? 0 : 1);
return !validator.is_success();
}
int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op)
......
......@@ -132,11 +132,11 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data());
pass &=
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
ck::utils::CorrectnessValidator validator;
validator.check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return (pass ? 0 : 1);
return !validator.is_success();
}
int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op)
......
......@@ -256,8 +256,10 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
out1_device_buf.FromDevice(out1_device.mData.data());
#endif
return ck::utils::check_err(
ck::utils::CorrectnessValidator validator;
validator.check_err(
out1_device, out1_host, "Error: incorrect results!", 1e-5f, 1e-4f);
return validator.is_success();
}
return true;
......
......@@ -89,7 +89,7 @@ int run_groupnorm_example(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< device_instance.GetTypeString() << std::endl;
bool pass = true;
ck::utils::CorrectnessValidator validator;
{
Tensor<YDataType> host_y({N, H, W, G, C});
using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm<XDataType,
......@@ -106,8 +106,8 @@ int run_groupnorm_example(int argc, char* argv[])
ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
validator.check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
}
return (pass ? 0 : 1);
return !validator.is_success();
}
......@@ -398,9 +398,9 @@ int main(int argc, char* argv[])
cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx));
});
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
ck::utils::CorrectnessValidator validator;
validator.check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData);
return !validator.is_success();
}
return 0;
......
......@@ -398,9 +398,9 @@ int main(int argc, char* argv[])
cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx));
});
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
ck::utils::CorrectnessValidator validator;
validator.check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData);
return validator.is_success();
}
return 0;
......
......@@ -100,7 +100,7 @@ int main()
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
ck::utils::CorrectnessValidator validator;
if(do_verification)
{
......@@ -108,9 +108,8 @@ int main()
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{});
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
validator.check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
return pass ? 0 : 1;
return !validator.is_success();
}
......@@ -110,7 +110,7 @@ int main()
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
ck::utils::CorrectnessValidator validator;
if(do_verification)
{
......@@ -122,9 +122,8 @@ int main()
host_b, a, nchw, PassThrough{});
// LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
validator.check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
return pass ? 0 : 1;
return !validator.is_success();
}
......@@ -156,7 +156,7 @@ int main()
std::cout << "Bandwidth is : " << bandwidth << "GB/s . " << std::endl;
std::cout << "Time elapase is : " << ela_time << " ms . " << std::endl;
bool pass = true;
ck::utils::CorrectnessValidator validator;
{
std::vector<std::size_t> mn = {static_cast<unsigned long>(M),
static_cast<unsigned long>(N)};
......@@ -184,12 +184,11 @@ int main()
ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data());
pass &=
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
if(!(pass))
validator.check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
if(!validator.is_success())
{
std::cout << "layernorm wrong" << std::endl;
}
}
return (pass ? 0 : 1);
return !validator.is_success();
}
......@@ -123,7 +123,9 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
ck::utils::CorrectnessValidator validator;
validator.check_err(e_m_n_device_result, e_m_n_host_result);
return validator.is_success();
}
return true;
......
......@@ -396,13 +396,13 @@ int main(int argc, char* argv[])
double rtol = 1e-3;
double atol = 1e-3;
return ck::utils::check_err(c_gs_ms_os_device_result.mData,
ck::utils::CorrectnessValidator validator;
validator.check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol)
? 0
: 1;
atol);
return !validator.is_success();
}
return 0;
......
......@@ -183,16 +183,18 @@ bool pool3d_test(bool do_verification,
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_n_c_do_ho_wo_device, out_n_c_do_ho_wo_host);
ck::utils::CorrectnessValidator validator;
validator.check_err(out_n_c_do_ho_wo_device, out_n_c_do_ho_wo_host);
if constexpr(OutputIndex)
{
out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_indices_n_c_do_ho_wo_device,
ck::utils::CorrectnessValidator validator;
validator.check_err(out_indices_n_c_do_ho_wo_device,
out_indices_n_c_do_ho_wo_host);
};
}
return (pass);
return validator.is_success();
};
......@@ -174,7 +174,7 @@ bool maxpool_bwd_test(bool do_verification,
std::cout << "Pool fwd perf: " << ave_time_fwd << " ms" << std::endl;
std::cout << "Pool bwd perf: " << ave_time_bwd << " ms" << std::endl;
bool pass = true;
ck::utils::CorrectnessValidator validator;
if(do_verification)
{
......@@ -219,10 +219,10 @@ bool maxpool_bwd_test(bool do_verification,
indices_device_buf.FromDevice(indices_n_c_ho_wo_device.mData.data());
din_device_buf.FromDevice(din_n_c_hi_wi_device.mData.data());
pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host);
pass = pass && ck::utils::check_err(indices_n_c_ho_wo_device, indices_n_c_ho_wo_host);
pass = pass && ck::utils::check_err(din_n_c_hi_wi_device, din_n_c_hi_wi_host);
validator.check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host);
validator.check_err(indices_n_c_ho_wo_device, indices_n_c_ho_wo_host);
validator.check_err(din_n_c_hi_wi_device, din_n_c_hi_wi_host);
}
return (pass);
return validator.is_success();
};
......@@ -69,7 +69,7 @@ int main()
std::cout << "perf: " << ave_time << " ms" << std::endl;
bool pass = true;
ck::utils::CorrectnessValidator validator;
if(do_verification)
{
Tensor<YDataType> y_host(HostTensorDescriptor{N});
......@@ -81,8 +81,8 @@ int main()
}
y_device_buf.FromDevice(y.mData.data());
pass = ck::utils::check_err(y, y_host);
validator.check_err(y, y_host);
}
return (pass ? 0 : 1);
return !validator.is_success();
}
......@@ -120,7 +120,7 @@ bool pool3d_bwd_test(bool do_verification,
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::cout << "Perf: " << ave_time << std::endl;
bool pass = true;
ck::utils::CorrectnessValidator validator;
if(do_verification)
{
......@@ -140,8 +140,8 @@ bool pool3d_bwd_test(bool do_verification,
ref_invoker.Run(ref_argument);
din_device_buf.FromDevice(din_dev.mData.data());
pass = ck::utils::check_err(din_dev, din_host);
validator.check_err(din_dev, din_host);
}
return pass;
return validator.is_success();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment