Commit ba9361c8 authored by Umang Yadav's avatar Umang Yadav
Browse files

rename verify_range_with_tolerance

parent 27997410
...@@ -254,7 +254,7 @@ bool verify_range(const R1& r1, ...@@ -254,7 +254,7 @@ bool verify_range(const R1& r1,
} }
template <class R1, class R2> template <class R1, class R2>
bool verify_range_with_threshold(const R1& r1, bool verify_range_with_tolerance(const R1& r1,
const expected<R2>& r2, const expected<R2>& r2,
tolerance tols = tolerance{}, tolerance tols = tolerance{},
double* out_error = nullptr) double* out_error = nullptr)
...@@ -270,7 +270,7 @@ bool verify_range_with_threshold(const R1& r1, ...@@ -270,7 +270,7 @@ bool verify_range_with_threshold(const R1& r1,
// expected argument should be passed as second, but if it is passed as the first by mistake then // expected argument should be passed as second, but if it is passed as the first by mistake then
// flip the order // flip the order
template <class R1, class R2> template <class R1, class R2>
bool verify_range_with_threshold(const expected<R1>& r1, bool verify_range_with_tolerance(const expected<R1>& r1,
const R2& r2, const R2& r2,
tolerance tols = tolerance{}, tolerance tols = tolerance{},
double* out_error = nullptr) double* out_error = nullptr)
......
...@@ -35,7 +35,7 @@ bool verify_args_with_threshold(const std::string& name, ...@@ -35,7 +35,7 @@ bool verify_args_with_threshold(const std::string& name,
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error; double error;
passed = verify::verify_range_with_threshold(target, verify::expected{ref}, tols, &error); passed = verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
......
...@@ -117,7 +117,7 @@ TEST_CASE(int8_quantization) ...@@ -117,7 +117,7 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much // the regular pipeline uses the rewrite_quantization in the much
// earlier stage. // earlier stage.
if(migraphx::gpu::mlir_enabled()) if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_tolerance(
gpu_result, gpu_result,
migraphx::verify::expected{ref_result}, migraphx::verify::expected{ref_result},
migraphx::verify::tolerance{0.01})); migraphx::verify::tolerance{0.01}));
......
...@@ -1077,7 +1077,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1077,7 +1077,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_tolerance(
quant_result, quant_result,
migraphx::verify::expected{no_quant_result}, migraphx::verify::expected{no_quant_result},
migraphx::verify::tolerance{0.003})); migraphx::verify::tolerance{0.003}));
......
...@@ -78,7 +78,7 @@ void dot_2d_test() ...@@ -78,7 +78,7 @@ void dot_2d_test()
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_tolerance(
results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6})); results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6}));
} }
TEST_CASE_REGISTER(dot_2d_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
...@@ -130,7 +130,7 @@ void dot_4d_test() ...@@ -130,7 +130,7 @@ void dot_4d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_tolerance(
results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6})); results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6}));
} }
......
...@@ -78,6 +78,6 @@ TEST_CASE(multinomial_test) ...@@ -78,6 +78,6 @@ TEST_CASE(multinomial_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_tolerance(
res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01})); res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
} }
...@@ -68,7 +68,7 @@ TEST_CASE(random_uniform_test) ...@@ -68,7 +68,7 @@ TEST_CASE(random_uniform_test)
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range_with_threshold(result_vec, EXPECT(migraphx::verify::verify_range_with_tolerance(result_vec,
migraphx::verify::expected{rand_samples}, migraphx::verify::expected{rand_samples},
migraphx::verify::tolerance{0.00001})); migraphx::verify::tolerance{0.00001}));
} }
......
...@@ -1008,7 +1008,7 @@ TEST_CASE(rnn_fp16) ...@@ -1008,7 +1008,7 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_tolerance(
last_output_data, last_output_data,
migraphx::verify::expected{last_output_data_gold}, migraphx::verify::expected{last_output_data_gold},
migraphx::verify::tolerance{0.005})); migraphx::verify::tolerance{0.005}));
...@@ -2986,7 +2986,7 @@ TEST_CASE(gru_fp16) ...@@ -2986,7 +2986,7 @@ TEST_CASE(gru_fp16)
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range_with_threshold( EXPECT(migraphx::verify::verify_range_with_tolerance(
hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::tolerance{0.005})); hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::tolerance{0.005}));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment