Commit 80cbf756 authored by Umang Yadav's avatar Umang Yadav
Browse files

Use verify_range_with_threshold

parent cc05e9d4
...@@ -208,7 +208,7 @@ bool verify_range(const R1& r1, ...@@ -208,7 +208,7 @@ bool verify_range(const R1& r1,
} }
template <class R1, class R2> template <class R1, class R2>
bool verify_range(const R1& r1, const R2& r2, double threshold, double* out_error = nullptr) bool verify_range_with_threshold(const R1& r1, const R2& r2, double threshold, double* out_error = nullptr)
{ {
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
if(out_error != nullptr) if(out_error != nullptr)
......
...@@ -35,7 +35,7 @@ bool verify_args(const std::string& name, ...@@ -35,7 +35,7 @@ bool verify_args(const std::string& name,
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error; double error;
passed = verify::verify_range(ref, target, threshold, &error); passed = verify::verify_range_with_threshold(ref, target, threshold, &error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
......
...@@ -118,7 +118,7 @@ TEST_CASE(int8_quantization) ...@@ -118,7 +118,7 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much // the regular pipeline uses the rewrite_quantization in the much
// earlier stage. // earlier stage.
if(migraphx::gpu::mlir_enabled()) if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 0.01)); EXPECT(migraphx::verify::verify_range_with_threshold(ref_result, gpu_result, 0.01));
else else
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
} }
......
...@@ -1077,7 +1077,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1077,7 +1077,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 0.003)); EXPECT(migraphx::verify::verify_range_with_threshold(quant_result, no_quant_result, 0.003));
} }
} }
......
...@@ -78,5 +78,5 @@ TEST_CASE(multinomial_test) ...@@ -78,5 +78,5 @@ TEST_CASE(multinomial_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range(norm, res_norm, 0.01)); EXPECT(migraphx::verify::verify_range_with_threshold(norm, res_norm, 0.01));
} }
...@@ -68,7 +68,7 @@ TEST_CASE(random_uniform_test) ...@@ -68,7 +68,7 @@ TEST_CASE(random_uniform_test)
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 0.00001)); EXPECT(migraphx::verify::verify_range_with_threshold(result_vec, rand_samples, 0.00001));
} }
TEST_CASE(random_uniform_int_test) TEST_CASE(random_uniform_int_test)
......
...@@ -1008,7 +1008,7 @@ TEST_CASE(rnn_fp16) ...@@ -1008,7 +1008,7 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold, 0.005)); EXPECT(migraphx::verify::verify_range_with_threshold(last_output_data, last_output_data_gold, 0.005));
} }
TEST_CASE(gru_forward) TEST_CASE(gru_forward)
...@@ -2983,7 +2983,7 @@ TEST_CASE(gru_fp16) ...@@ -2983,7 +2983,7 @@ TEST_CASE(gru_fp16)
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 0.005)); EXPECT(migraphx::verify::verify_range_with_threshold(hs_data, hs_data_gold, 0.005));
} }
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment