Commit 27997410 authored by Umang Yadav's avatar Umang Yadav
Browse files

rename threshold to tolerance

parent 97226839
...@@ -536,7 +536,7 @@ struct params : command<params> ...@@ -536,7 +536,7 @@ struct params : command<params>
struct verify : command<verify> struct verify : command<verify>
{ {
compiler c; compiler c;
migraphx::verify::threshold tols; migraphx::verify::tolerance tols;
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
void parse(argument_parser& ap) void parse(argument_parser& ap)
......
...@@ -77,7 +77,7 @@ void verify_program(const std::string& name, ...@@ -77,7 +77,7 @@ void verify_program(const std::string& name,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
verify::threshold tols) verify::tolerance tols)
{ {
auto x = run_ref(p, inputs); auto x = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs); auto y = run_target(p, t, options, quantize, inputs);
...@@ -93,7 +93,7 @@ void verify_instructions(const program& prog, ...@@ -93,7 +93,7 @@ void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize, precision quantize,
verify::threshold tols) verify::tolerance tols)
{ {
const auto* mm_prog = prog.get_main_module(); const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog)) for(auto&& ins : (*mm_prog))
...@@ -140,7 +140,7 @@ void verify_reduced(program p, ...@@ -140,7 +140,7 @@ void verify_reduced(program p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
verify::threshold tols) verify::tolerance tols)
{ {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1); auto last = std::prev(mm->end(), n + 1);
...@@ -155,7 +155,7 @@ void verify_reduced_program(const program& p, ...@@ -155,7 +155,7 @@ void verify_reduced_program(const program& p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
verify::threshold tols) verify::tolerance tols)
{ {
const auto* mm = p.get_main_module(); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
......
...@@ -38,18 +38,18 @@ void verify_program(const std::string& name, ...@@ -38,18 +38,18 @@ void verify_program(const std::string& name,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
verify::threshold tols = verify::threshold{}); verify::tolerance tols = verify::tolerance{});
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
verify::threshold tols = verify::threshold{}); verify::tolerance tols = verify::tolerance{});
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
verify::threshold tols = verify::threshold{}); verify::tolerance tols = verify::tolerance{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -214,7 +214,7 @@ struct expected ...@@ -214,7 +214,7 @@ struct expected
template <class T> template <class T>
expected(const T&) -> expected<T>; expected(const T&) -> expected<T>;
struct threshold struct tolerance
{ {
double rms_tol = 0.001; double rms_tol = 0.001;
double atol = 0.001; double atol = 0.001;
...@@ -222,7 +222,7 @@ struct threshold ...@@ -222,7 +222,7 @@ struct threshold
}; };
template <class R1, class R2> template <class R1, class R2>
bool allclose(const R1& r1, const R2& r2, threshold thres) bool allclose(const R1& r1, const R2& r2, tolerance thres)
{ {
std::size_t n = range_distance(r1); std::size_t n = range_distance(r1);
if(n == range_distance(r2)) if(n == range_distance(r2))
...@@ -256,7 +256,7 @@ bool verify_range(const R1& r1, ...@@ -256,7 +256,7 @@ bool verify_range(const R1& r1,
template <class R1, class R2> template <class R1, class R2>
bool verify_range_with_threshold(const R1& r1, bool verify_range_with_threshold(const R1& r1,
const expected<R2>& r2, const expected<R2>& r2,
threshold tols = threshold{}, tolerance tols = tolerance{},
double* out_error = nullptr) double* out_error = nullptr)
{ {
auto rms_error = rms_range(r1, r2.data()); auto rms_error = rms_range(r1, r2.data());
...@@ -272,7 +272,7 @@ bool verify_range_with_threshold(const R1& r1, ...@@ -272,7 +272,7 @@ bool verify_range_with_threshold(const R1& r1,
template <class R1, class R2> template <class R1, class R2>
bool verify_range_with_threshold(const expected<R1>& r1, bool verify_range_with_threshold(const expected<R1>& r1,
const R2& r2, const R2& r2,
threshold tols = threshold{}, tolerance tols = tolerance{},
double* out_error = nullptr) double* out_error = nullptr)
{ {
return verify_range(r2, r1, tols, out_error); return verify_range(r2, r1, tols, out_error);
......
...@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT bool verify_args_with_threshold(const std::string& name, MIGRAPHX_EXPORT bool verify_args_with_threshold(const std::string& name,
const argument& ref_arg, const argument& ref_arg,
const argument& target_arg, const argument& target_arg,
verify::threshold); verify::tolerance);
MIGRAPHX_EXPORT bool verify_args(const std::string& name, MIGRAPHX_EXPORT bool verify_args(const std::string& name,
const argument& target_arg, const argument& target_arg,
......
...@@ -30,7 +30,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -30,7 +30,7 @@ inline namespace MIGRAPHX_INLINE_NS {
bool verify_args_with_threshold(const std::string& name, bool verify_args_with_threshold(const std::string& name,
const argument& target_arg, const argument& target_arg,
const argument& ref_arg, const argument& ref_arg,
verify::threshold tols) verify::tolerance tols)
{ {
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
...@@ -100,7 +100,7 @@ bool verify_args(const std::string& name, ...@@ -100,7 +100,7 @@ bool verify_args(const std::string& name,
{ {
double rms_tol = 0.001; double rms_tol = 0.001;
target_arg.visit([&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); }); target_arg.visit([&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); });
verify::threshold tols{rms_tol}; verify::tolerance tols{rms_tol};
return verify_args_with_threshold(name, target_arg, ref_arg, tols); return verify_args_with_threshold(name, target_arg, ref_arg, tols);
} }
......
...@@ -120,7 +120,7 @@ TEST_CASE(int8_quantization) ...@@ -120,7 +120,7 @@ TEST_CASE(int8_quantization)
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_threshold(
gpu_result, gpu_result,
migraphx::verify::expected{ref_result}, migraphx::verify::expected{ref_result},
migraphx::verify::threshold{0.01})); migraphx::verify::tolerance{0.01}));
else else
EXPECT(migraphx::verify::verify_range(gpu_result, ref_result)); EXPECT(migraphx::verify::verify_range(gpu_result, ref_result));
} }
......
...@@ -1080,7 +1080,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1080,7 +1080,7 @@ TEST_CASE(int8_quantization_dot)
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_threshold(
quant_result, quant_result,
migraphx::verify::expected{no_quant_result}, migraphx::verify::expected{no_quant_result},
migraphx::verify::threshold{0.003})); migraphx::verify::tolerance{0.003}));
} }
} }
......
...@@ -79,7 +79,7 @@ void dot_2d_test() ...@@ -79,7 +79,7 @@ void dot_2d_test()
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_threshold(
results_vector, migraphx::verify::expected{gold}, migraphx::verify::threshold{9e-6})); results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6}));
} }
TEST_CASE_REGISTER(dot_2d_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>) TEST_CASE_REGISTER(dot_2d_test<double>)
...@@ -131,7 +131,7 @@ void dot_4d_test() ...@@ -131,7 +131,7 @@ void dot_4d_test()
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_threshold(
results_vector, migraphx::verify::expected{gold}, migraphx::verify::threshold{9e-6})); results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6}));
} }
TEST_CASE_REGISTER(dot_4d_test<float>) TEST_CASE_REGISTER(dot_4d_test<float>)
......
...@@ -79,5 +79,5 @@ TEST_CASE(multinomial_test) ...@@ -79,5 +79,5 @@ TEST_CASE(multinomial_test)
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_threshold(
res_norm, migraphx::verify::expected{norm}, migraphx::verify::threshold{0.01})); res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
} }
...@@ -70,7 +70,7 @@ TEST_CASE(random_uniform_test) ...@@ -70,7 +70,7 @@ TEST_CASE(random_uniform_test)
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range_with_threshold(result_vec, EXPECT(migraphx::verify::verify_range_with_threshold(result_vec,
migraphx::verify::expected{rand_samples}, migraphx::verify::expected{rand_samples},
migraphx::verify::threshold{0.00001})); migraphx::verify::tolerance{0.00001}));
} }
TEST_CASE(random_uniform_int_test) TEST_CASE(random_uniform_int_test)
......
...@@ -1009,7 +1009,9 @@ TEST_CASE(rnn_fp16) ...@@ -1009,7 +1009,9 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_threshold(
last_output_data, migraphx::verify::expected{last_output_data_gold}, migraphx::verify::threshold{0.005})); last_output_data,
migraphx::verify::expected{last_output_data_gold},
migraphx::verify::tolerance{0.005}));
} }
TEST_CASE(gru_forward) TEST_CASE(gru_forward)
...@@ -2985,7 +2987,7 @@ TEST_CASE(gru_fp16) ...@@ -2985,7 +2987,7 @@ TEST_CASE(gru_fp16)
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_threshold(
hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::threshold{0.005})); hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::tolerance{0.005}));
} }
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment