Commit 27997410 authored by Umang Yadav's avatar Umang Yadav
Browse files

rename threshold to tolerance

parent 97226839
......@@ -536,7 +536,7 @@ struct params : command<params>
struct verify : command<verify>
{
compiler c;
migraphx::verify::threshold tols;
migraphx::verify::tolerance tols;
bool per_instruction = false;
bool reduce = false;
void parse(argument_parser& ap)
......
......@@ -77,7 +77,7 @@ void verify_program(const std::string& name,
compile_options options,
precision quantize,
const parameter_map& inputs,
verify::threshold tols)
verify::tolerance tols)
{
auto x = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs);
......@@ -93,7 +93,7 @@ void verify_instructions(const program& prog,
const target& t,
compile_options options,
precision quantize,
verify::threshold tols)
verify::tolerance tols)
{
const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog))
......@@ -140,7 +140,7 @@ void verify_reduced(program p,
compile_options options,
precision quantize,
const parameter_map& inputs,
verify::threshold tols)
verify::tolerance tols)
{
auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1);
......@@ -155,7 +155,7 @@ void verify_reduced_program(const program& p,
compile_options options,
precision quantize,
const parameter_map& inputs,
verify::threshold tols)
verify::tolerance tols)
{
const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end());
......
......@@ -38,18 +38,18 @@ void verify_program(const std::string& name,
compile_options options = compile_options{},
precision quantize = precision::fp32,
const parameter_map& inputs = {},
verify::threshold tols = verify::threshold{});
verify::tolerance tols = verify::tolerance{});
void verify_instructions(const program& prog,
const target& t,
compile_options options = compile_options{},
precision quantize = precision::fp32,
verify::threshold tols = verify::threshold{});
verify::tolerance tols = verify::tolerance{});
void verify_reduced_program(const program& p,
const target& t,
compile_options options = compile_options{},
precision quantize = precision::fp32,
const parameter_map& inputs = {},
verify::threshold tols = verify::threshold{});
verify::tolerance tols = verify::tolerance{});
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
......
......@@ -214,7 +214,7 @@ struct expected
template <class T>
expected(const T&) -> expected<T>;
struct threshold
struct tolerance
{
double rms_tol = 0.001;
double atol = 0.001;
......@@ -222,7 +222,7 @@ struct threshold
};
template <class R1, class R2>
bool allclose(const R1& r1, const R2& r2, threshold thres)
bool allclose(const R1& r1, const R2& r2, tolerance thres)
{
std::size_t n = range_distance(r1);
if(n == range_distance(r2))
......@@ -256,7 +256,7 @@ bool verify_range(const R1& r1,
template <class R1, class R2>
bool verify_range_with_threshold(const R1& r1,
const expected<R2>& r2,
threshold tols = threshold{},
tolerance tols = tolerance{},
double* out_error = nullptr)
{
auto rms_error = rms_range(r1, r2.data());
......@@ -272,7 +272,7 @@ bool verify_range_with_threshold(const R1& r1,
template <class R1, class R2>
bool verify_range_with_threshold(const expected<R1>& r1,
const R2& r2,
threshold tols = threshold{},
tolerance tols = tolerance{},
double* out_error = nullptr)
{
return verify_range(r2, r1, tols, out_error);
......
......@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT bool verify_args_with_threshold(const std::string& name,
const argument& ref_arg,
const argument& target_arg,
verify::threshold);
verify::tolerance);
MIGRAPHX_EXPORT bool verify_args(const std::string& name,
const argument& target_arg,
......
......@@ -30,7 +30,7 @@ inline namespace MIGRAPHX_INLINE_NS {
bool verify_args_with_threshold(const std::string& name,
const argument& target_arg,
const argument& ref_arg,
verify::threshold tols)
verify::tolerance tols)
{
bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
......@@ -100,7 +100,7 @@ bool verify_args(const std::string& name,
{
double rms_tol = 0.001;
target_arg.visit([&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); });
verify::threshold tols{rms_tol};
verify::tolerance tols{rms_tol};
return verify_args_with_threshold(name, target_arg, ref_arg, tols);
}
......
......@@ -120,7 +120,7 @@ TEST_CASE(int8_quantization)
EXPECT(migraphx::verify::verify_range_with_threshold(
gpu_result,
migraphx::verify::expected{ref_result},
migraphx::verify::threshold{0.01}));
migraphx::verify::tolerance{0.01}));
else
EXPECT(migraphx::verify::verify_range(gpu_result, ref_result));
}
......
......@@ -1080,7 +1080,7 @@ TEST_CASE(int8_quantization_dot)
EXPECT(migraphx::verify::verify_range_with_threshold(
quant_result,
migraphx::verify::expected{no_quant_result},
migraphx::verify::threshold{0.003}));
migraphx::verify::tolerance{0.003}));
}
}
......
......@@ -79,7 +79,7 @@ void dot_2d_test()
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify::verify_range_with_threshold(
results_vector, migraphx::verify::expected{gold}, migraphx::verify::threshold{9e-6}));
results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6}));
}
TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>)
......@@ -131,7 +131,7 @@ void dot_4d_test()
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range_with_threshold(
results_vector, migraphx::verify::expected{gold}, migraphx::verify::threshold{9e-6}));
results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6}));
}
TEST_CASE_REGISTER(dot_4d_test<float>)
......
......@@ -79,5 +79,5 @@ TEST_CASE(multinomial_test)
return static_cast<double>(n) / res_dist_sum;
});
EXPECT(migraphx::verify::verify_range_with_threshold(
res_norm, migraphx::verify::expected{norm}, migraphx::verify::threshold{0.01}));
res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
}
......@@ -70,7 +70,7 @@ TEST_CASE(random_uniform_test)
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range_with_threshold(result_vec,
migraphx::verify::expected{rand_samples},
migraphx::verify::threshold{0.00001}));
migraphx::verify::tolerance{0.00001}));
}
TEST_CASE(random_uniform_int_test)
......
......@@ -1009,7 +1009,9 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range_with_threshold(
last_output_data, migraphx::verify::expected{last_output_data_gold}, migraphx::verify::threshold{0.005}));
last_output_data,
migraphx::verify::expected{last_output_data_gold},
migraphx::verify::tolerance{0.005}));
}
TEST_CASE(gru_forward)
......@@ -2985,7 +2987,7 @@ TEST_CASE(gru_fp16)
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range_with_threshold(
hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::threshold{0.005}));
hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::tolerance{0.005}));
}
TEST_CASE(lstm_forward)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment