"include/vscode:/vscode.git/clone" did not exist on "16b116a99f3aa39eab3b7f9381cb5ad21e409f5c"
Commit 78b8f79b authored by Umang Yadav's avatar Umang Yadav
Browse files

expected working

parent 78142841
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <assert.h>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -194,6 +195,25 @@ double get_threshold(const R&, std::size_t tolerance = 80) ...@@ -194,6 +195,25 @@ double get_threshold(const R&, std::size_t tolerance = 80)
return threshold; return threshold;
} }
template <class T>
struct expected
{
expected() = default;
expected(const T& input) : x(&input) {}
const T& data() const
{
assert(x != nullptr);
return *x;
}
private:
const T* x = nullptr;
};
// deduction guide for templated expected class
template <class T>
expected(const T&) -> expected<T>;
template <class R1, class R2> template <class R1, class R2>
bool verify_range(const R1& r1, bool verify_range(const R1& r1,
const R2& r2, const R2& r2,
...@@ -208,9 +228,12 @@ bool verify_range(const R1& r1, ...@@ -208,9 +228,12 @@ bool verify_range(const R1& r1,
} }
template <class R1, class R2> template <class R1, class R2>
bool verify_range_with_threshold(const R1& r1, const R2& r2, double threshold, double* out_error = nullptr) bool verify_range_with_threshold(const R1& r1,
const expected<R2>& r2,
double threshold,
double* out_error = nullptr)
{ {
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2.data());
if(out_error != nullptr) if(out_error != nullptr)
*out_error = error; *out_error = error;
return error <= threshold; return error <= threshold;
......
...@@ -28,14 +28,15 @@ namespace migraphx { ...@@ -28,14 +28,15 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool verify_args_with_threshold(const std::string& name, bool verify_args_with_threshold(const std::string& name,
const argument& ref_arg,
const argument& target_arg, const argument& target_arg,
const argument& ref_arg,
double threshold) double threshold)
{ {
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error; double error;
passed = verify::verify_range_with_threshold(ref, target, threshold, &error); passed =
verify::verify_range_with_threshold(target, verify::expected{ref}, threshold, &error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
...@@ -100,7 +101,7 @@ bool verify_args(const std::string& name, ...@@ -100,7 +101,7 @@ bool verify_args(const std::string& name,
{ {
double threshold = 0.001; double threshold = 0.001;
target_arg.visit([&](auto ta) { threshold = verify::get_threshold(ta, tolerance); }); target_arg.visit([&](auto ta) { threshold = verify::get_threshold(ta, tolerance); });
return verify_args_with_threshold(name, ref_arg, target_arg, threshold); return verify_args_with_threshold(name, target_arg, ref_arg, threshold);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -118,7 +118,7 @@ TEST_CASE(int8_quantization) ...@@ -118,7 +118,7 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much // the regular pipeline uses the rewrite_quantization in the much
// earlier stage. // earlier stage.
if(migraphx::gpu::mlir_enabled()) if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range_with_threshold(ref_result, gpu_result, 0.01)); EXPECT(migraphx::verify::verify_range_with_threshold(gpu_result, migraphx::verify::expected{ref_result}, 0.01));
else else
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
} }
......
...@@ -1077,7 +1077,8 @@ TEST_CASE(int8_quantization_dot) ...@@ -1077,7 +1077,8 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify::verify_range_with_threshold(quant_result, no_quant_result, 0.003)); EXPECT(migraphx::verify::verify_range_with_threshold(
quant_result, migraphx::verify::expected{no_quant_result}, 0.003));
} }
} }
......
...@@ -78,5 +78,6 @@ TEST_CASE(multinomial_test) ...@@ -78,5 +78,6 @@ TEST_CASE(multinomial_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range_with_threshold(norm, res_norm, 0.01)); EXPECT(migraphx::verify::verify_range_with_threshold(
res_norm, migraphx::verify::expected{norm}, 0.01));
} }
...@@ -68,7 +68,8 @@ TEST_CASE(random_uniform_test) ...@@ -68,7 +68,8 @@ TEST_CASE(random_uniform_test)
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range_with_threshold(result_vec, rand_samples, 0.00001)); EXPECT(migraphx::verify::verify_range_with_threshold(
result_vec, migraphx::verify::expected{rand_samples}, 0.00001));
} }
TEST_CASE(random_uniform_int_test) TEST_CASE(random_uniform_int_test)
......
...@@ -1008,7 +1008,7 @@ TEST_CASE(rnn_fp16) ...@@ -1008,7 +1008,7 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range_with_threshold(last_output_data, last_output_data_gold, 0.005)); EXPECT(migraphx::verify::verify_range_with_threshold(last_output_data, migraphx::verify::expected{last_output_data_gold}, 0.005));
} }
TEST_CASE(gru_forward) TEST_CASE(gru_forward)
...@@ -2983,7 +2983,7 @@ TEST_CASE(gru_fp16) ...@@ -2983,7 +2983,7 @@ TEST_CASE(gru_fp16)
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range_with_threshold(hs_data, hs_data_gold, 0.005)); EXPECT(migraphx::verify::verify_range_with_threshold(hs_data, migraphx::verify::expected{hs_data_gold}, 0.005));
} }
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment