Unverified Commit 69d8d789 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add options to set tolerances inside MIGraphX driver (#2213)

MIGraphX verification by default uses normalized RMS error as the basis for the verification.  This change adds some logic to allow migraphx to do "np.allclose" type of elementwise verification using atol and rtol.

Commit also includes changes to consistently pass "gold" or "expected" results as the second argument for "verify_range()" calls.  Default RMS tolerance inside driver is set to 0.001 which IMO is high for FP32 compared to what we had earlier. Need better defaults
parent e12032fb
...@@ -131,7 +131,7 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro ...@@ -131,7 +131,7 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, sol)); EXPECT(migraphx::verify::verify_rms_range(results_vector, sol));
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU. An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU.
By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target. By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target.
......
...@@ -50,9 +50,17 @@ Runs reference and CPU or GPU implementations and checks outputs for consistency ...@@ -50,9 +50,17 @@ Runs reference and CPU or GPU implementations and checks outputs for consistency
.. include:: ./driver/compile.rst .. include:: ./driver/compile.rst
.. option:: --tolerance [double] .. option:: --rms-tol [double]
Tolerance for errors (Default: 80) Tolerance for RMS error (Default: 0.001)
.. option:: --atol [double]
Tolerance for elementwise absolute difference (Default: 0.001)
.. option:: --rtol [double]
Tolerance for elementwise relative difference (Default: 0.001)
.. option:: -i, --per-instruction .. option:: -i, --per-instruction
......
...@@ -55,7 +55,9 @@ See below for a comprehensive list of commands and option arguments, as well as ...@@ -55,7 +55,9 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel | | --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 | | --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 | | --int8 | Quantize for int8 |
| --tolerance | Tolerance for errors | | --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
| --per-instruction \| -i | Verify each instruction | | --per-instruction \| -i | Verify each instruction |
| --reduce \| -r | Reduce program and verify | | --reduce \| -r | Reduce program and verify |
| --iterations \| -n | Number of iterations to run for perf report | | --iterations \| -n | Number of iterations to run for perf report |
......
...@@ -536,13 +536,19 @@ struct params : command<params> ...@@ -536,13 +536,19 @@ struct params : command<params>
struct verify : command<verify> struct verify : command<verify>
{ {
compiler c; compiler c;
double tolerance = 80; migraphx::verify::tolerance tols;
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
c.parse(ap); c.parse(ap);
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors")); ap(tols.rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error (Default: 0.001)"));
ap(tols.atol,
{"--atol"},
ap.help("Tolerance for the elementwise absolute difference (Default: 0.001)"));
ap(tols.rtol,
{"--rtol"},
ap.help("Tolerance for the elementwise relative difference (Default: 0.001)"));
ap(per_instruction, ap(per_instruction,
{"-i", "--per-instruction"}, {"-i", "--per-instruction"},
ap.help("Verify each instruction"), ap.help("Verify each instruction"),
...@@ -567,15 +573,15 @@ struct verify : command<verify> ...@@ -567,15 +573,15 @@ struct verify : command<verify>
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, c.co, quantize, tolerance); verify_instructions(p, t, c.co, quantize, tols);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, c.co, quantize, m, tolerance); verify_reduced_program(p, t, c.co, quantize, m, tols);
} }
else else
{ {
verify_program(c.l.file, p, t, c.co, quantize, m, tolerance); verify_program(c.l.file, p, t, c.co, quantize, m, tols);
} }
} }
}; };
......
...@@ -77,24 +77,24 @@ void verify_program(const std::string& name, ...@@ -77,24 +77,24 @@ void verify_program(const std::string& name,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
auto x = run_ref(p, inputs); auto ref_outs = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs); auto target_outs = run_target(p, t, options, quantize, inputs);
std::size_t output_num = x.size(); std::size_t output_num = ref_outs.size();
for(std::size_t i = 0; i < output_num; ++i) for(std::size_t i = 0; i < output_num; ++i)
{ {
if(x[i].get_shape().type() != y[i].get_shape().type() or if(ref_outs[i].get_shape().type() != target_outs[i].get_shape().type() or
x[i].get_shape().lens() != y[i].get_shape().lens()) ref_outs[i].get_shape().lens() != target_outs[i].get_shape().lens())
{ {
std::cout << "FAILED: " << name << std::endl; std::cout << "FAILED: " << name << std::endl;
std::cout << "Shape mismatch {" << x[i].get_shape() << "} != {" << y[i].get_shape() std::cout << "Shape mismatch {" << ref_outs[i].get_shape() << "} != {"
<< "}" << std::endl; << target_outs[i].get_shape() << "}" << std::endl;
} }
else else
{ {
verify_args(name, x[i], y[i], tolerance); verify_args(name, target_outs[i], verify::expected{ref_outs[i]}, tols);
} }
} }
} }
...@@ -103,7 +103,7 @@ void verify_instructions(const program& prog, ...@@ -103,7 +103,7 @@ void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize, precision quantize,
double tolerance) verify::tolerance tols)
{ {
const auto* mm_prog = prog.get_main_module(); const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog)) for(auto&& ins : (*mm_prog))
...@@ -134,8 +134,7 @@ void verify_instructions(const program& prog, ...@@ -134,8 +134,7 @@ void verify_instructions(const program& prog,
{ {
std::cout << "Verify: " << ins.name() << std::endl; std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program( verify_program(ins.name(), p, t, options, quantize, create_param_map(p, false), tols);
ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance);
} }
catch(...) catch(...)
{ {
...@@ -151,7 +150,7 @@ void verify_reduced(program p, ...@@ -151,7 +150,7 @@ void verify_reduced(program p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n); auto last = std::prev(mm->end(), n);
...@@ -160,7 +159,7 @@ void verify_reduced(program p, ...@@ -160,7 +159,7 @@ void verify_reduced(program p,
std::cout << p << std::endl; std::cout << p << std::endl;
try try
{ {
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance); verify_program(std::to_string(n), p, t, options, quantize, inputs, tols);
} }
catch(const std::exception& e) catch(const std::exception& e)
{ {
...@@ -174,7 +173,7 @@ void verify_reduced_program(const program& p, ...@@ -174,7 +173,7 @@ void verify_reduced_program(const program& p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
const auto* mm = p.get_main_module(); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
...@@ -187,7 +186,7 @@ void verify_reduced_program(const program& p, ...@@ -187,7 +186,7 @@ void verify_reduced_program(const program& p,
std::cout << "Skip: " << i << std::endl; std::cout << "Skip: " << i << std::endl;
continue; continue;
} }
verify_reduced(p, i, t, options, quantize, inputs, tolerance); verify_reduced(p, i, t, options, quantize, inputs, tols);
} }
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "precision.hpp" #include "precision.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
...@@ -37,18 +38,18 @@ void verify_program(const std::string& name, ...@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 100); verify::tolerance tols = verify::tolerance{});
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
double tolerance = 80); verify::tolerance tols = verify::tolerance{});
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 80); verify::tolerance tols = verify::tolerance{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -29,10 +29,13 @@ ...@@ -29,10 +29,13 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <assert.h>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace verify { namespace verify {
...@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2) ...@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2)
return std::numeric_limits<range_value<R1>>::max(); return std::numeric_limits<range_value<R1>>::max();
} }
template <class R>
double get_rms_tol(const R&, std::size_t tolerance = 80)
{
double threshold = std::numeric_limits<range_value<R>>::epsilon() * tolerance;
return threshold;
}
/*
C++ doesn't support named arguments, this is just wrapper that helps distinguish between actual
results v/s expected results arguments.
*/
template <class T>
struct expected
{
expected() = default;
explicit expected(const T& input) : x(&input) {}
const T& data() const
{
assert(x != nullptr);
return *x;
}
private:
const T* x = nullptr;
};
// deduction guide for templated expected class
template <class T>
expected(const T&) -> expected<T>;
struct tolerance
{
double rms_tol = 0.001;
double atol = 0.001;
double rtol = 0.001;
};
/*
MIGraphX implementation of numpy's np.allclose() which checks if elementwise absolute diff is within
tolerance using this formula: abs(a - b) < atol + rtol(abs(b))
*/
template <class R1, class R2> template <class R1, class R2>
bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr) bool allclose(const R1& r1, const R2& r2, tolerance tols)
{ {
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance; std::size_t n = range_distance(r1);
if(n == range_distance(r2))
{
auto idx = mismatch_idx(r1, r2, [&](auto x, auto y) {
return abs_diff(double(x), double(y)) < tols.atol + tols.rtol * std::abs(double(y));
});
return idx >= range_distance(r1);
}
return false;
}
template <class R1, class R2>
bool verify_rms_range(const R1& r1,
const R2& r2,
std::size_t tolerance = 80,
double* out_rms_error = nullptr)
{
double threshold = get_rms_tol(r1, tolerance);
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
if(out_error != nullptr) if(out_rms_error != nullptr)
*out_error = error; *out_rms_error = error;
return error <= threshold; return error <= threshold;
} }
template <class R1, class R2>
bool verify_range_with_tolerance(const R1& r1,
const expected<R2>& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
auto rms_error = rms_range(r1, r2.data());
// disable ewise_verify by default for now, it requires lot of tests to be fixed
bool ewise_verify = true;
if(enabled(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE{}))
{
ewise_verify = allclose(r1, r2.data(), tols);
}
if(out_rms_error != nullptr)
*out_rms_error = rms_error;
return rms_error <= tols.rms_tol and ewise_verify;
}
// expected argument should be passed as second, but if it is passed as the first by mistake then
// flip the order
template <class R1, class R2>
bool verify_range_with_tolerance(const expected<R1>& r1,
const R2& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
return verify_rms_range(r2, r1, tols, out_rms_error);
}
} // namespace verify } // namespace verify
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -31,11 +31,15 @@ ...@@ -31,11 +31,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT MIGRAPHX_EXPORT bool verify_args(const std::string& name,
bool verify_args(const std::string& name,
const argument& ref_arg,
const argument& target_arg, const argument& target_arg,
double tolerance = 80); const verify::expected<argument>& ref_arg,
verify::tolerance);
MIGRAPHX_EXPORT bool verify_args_with_tolerance(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -28,19 +28,20 @@ namespace migraphx { ...@@ -28,19 +28,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool verify_args(const std::string& name, bool verify_args(const std::string& name,
const argument& ref_arg,
const argument& target_arg, const argument& target_arg,
double tolerance) const verify::expected<argument>& ref_arg,
verify::tolerance tols)
{ {
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
double error; double rms_error;
passed = verify::verify_range(ref, target, tolerance, &error); passed =
verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
std::cout << "FAILED: " << name << std::endl; std::cout << "FAILED: " << name << std::endl;
std::cout << "error: " << error << std::endl; std::cout << "RMS Error: " << rms_error << std::endl;
if(ref.size() < 32) if(ref.size() < 32)
std::cout << "ref:" << ref << std::endl; std::cout << "ref:" << ref << std::endl;
if(target.size() < 32) if(target.size() < 32)
...@@ -93,5 +94,16 @@ bool verify_args(const std::string& name, ...@@ -93,5 +94,16 @@ bool verify_args(const std::string& name,
return passed; return passed;
} }
bool verify_args_with_tolerance(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance)
{
double rms_tol = 0.001;
target_arg.visit([&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); });
verify::tolerance tols{rms_tol};
return verify_args(name, target_arg, ref_arg, tols);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test) ...@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
migraphx::target gpu_t = migraphx::make_target("gpu"); migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result); run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -53,7 +53,6 @@ TEST_CASE(host_same_buffer_copy) ...@@ -53,7 +53,6 @@ TEST_CASE(host_same_buffer_copy)
migraphx::parameter_map pp; migraphx::parameter_map pp;
std::vector<float> a_vec(ss.elements(), -1); std::vector<float> a_vec(ss.elements(), -1);
std::vector<float> b_vec(ss.elements(), 2); std::vector<float> b_vec(ss.elements(), 2);
std::vector<float> c_vec(ss.elements(), 0);
pp["a"] = migraphx::argument(ss, a_vec.data()); pp["a"] = migraphx::argument(ss, a_vec.data());
pp["b"] = migraphx::argument(ss, b_vec.data()); pp["b"] = migraphx::argument(ss, b_vec.data());
std::vector<float> gpu_result; std::vector<float> gpu_result;
...@@ -64,7 +63,8 @@ TEST_CASE(host_same_buffer_copy) ...@@ -64,7 +63,8 @@ TEST_CASE(host_same_buffer_copy)
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
std::vector<float> results_vector(ss.elements(), -1); std::vector<float> results_vector(ss.elements(), -1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(c_vec, results_vector)); std::vector<float> gold_vec(ss.elements(), 0);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_vec));
} }
TEST_CASE(arguments_lifetime) TEST_CASE(arguments_lifetime)
......
...@@ -133,7 +133,8 @@ bool verify_mlir(const migraphx::module& mmlir) ...@@ -133,7 +133,8 @@ bool verify_mlir(const migraphx::module& mmlir)
auto inputs = generate_params(ref); auto inputs = generate_params(ref);
auto mlir = create_program_from_mlir(mmlir); auto mlir = create_program_from_mlir(mmlir);
return migraphx::verify_args("mlir", run_ref(ref, inputs), run_gpu(mlir, inputs)); return migraphx::verify_args_with_tolerance(
"mlir", run_gpu(mlir, inputs), migraphx::verify::expected{run_ref(ref, inputs)});
} }
TEST_CASE(conv) TEST_CASE(conv)
......
...@@ -40,7 +40,6 @@ ...@@ -40,7 +40,6 @@
TEST_CASE(gpu_target_copy) TEST_CASE(gpu_target_copy)
{ {
migraphx::target gpu_t = migraphx::make_target("gpu"); migraphx::target gpu_t = migraphx::make_target("gpu");
migraphx::target ref_t = migraphx::make_target("ref");
migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}};
auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L); auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L);
...@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy) ...@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy)
std::vector<int8_t> val_final; std::vector<int8_t> val_final;
ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); }); ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
EXPECT(migraphx::verify::verify_range(val_orig, val_final)); EXPECT(migraphx::verify::verify_rms_range(val_orig, val_final));
} }
TEST_CASE(int8_quantization) TEST_CASE(int8_quantization)
...@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization) ...@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much // the regular pipeline uses the rewrite_quantization in the much
// earlier stage. // earlier stage.
if(migraphx::gpu::mlir_enabled()) if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 1e5)); EXPECT(migraphx::verify::verify_range_with_tolerance(
gpu_result,
migraphx::verify::expected{ref_result},
migraphx::verify::tolerance{0.01}));
else else
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
} }
} }
......
...@@ -47,7 +47,7 @@ TEST_CASE(averagepool_notset_test) ...@@ -47,7 +47,7 @@ TEST_CASE(averagepool_notset_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {12}; std::vector<float> gold = {12};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(averagepool_nt_cip_test) TEST_CASE(averagepool_nt_cip_test)
...@@ -65,7 +65,7 @@ TEST_CASE(averagepool_nt_cip_test) ...@@ -65,7 +65,7 @@ TEST_CASE(averagepool_nt_cip_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {8.33333}; std::vector<float> gold = {8.33333};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_flat_test) TEST_CASE(batch_norm_flat_test)
...@@ -111,7 +111,7 @@ TEST_CASE(batch_norm_flat_test) ...@@ -111,7 +111,7 @@ TEST_CASE(batch_norm_flat_test)
0.43305403, 0.43305403,
0.4408022, 0.4408022,
0.42019472}; 0.42019472};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_rank_2_test) TEST_CASE(batch_norm_rank_2_test)
...@@ -148,7 +148,7 @@ TEST_CASE(batch_norm_rank_2_test) ...@@ -148,7 +148,7 @@ TEST_CASE(batch_norm_rank_2_test)
9.89948504, 9.89948504,
9.89948504, 9.89948504,
12.72790933}; 12.72790933};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_1d_test) TEST_CASE(batch_norm_1d_test)
...@@ -184,7 +184,7 @@ TEST_CASE(batch_norm_1d_test) ...@@ -184,7 +184,7 @@ TEST_CASE(batch_norm_1d_test)
0.4927, 0.771, -1.956, -2.123, -0.664, -0.583, -0.7207, -0.5127}; 0.4927, 0.771, -1.956, -2.123, -0.664, -0.583, -0.7207, -0.5127};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()}; std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_2d_test) TEST_CASE(batch_norm_2d_test)
...@@ -250,7 +250,7 @@ TEST_CASE(batch_norm_2d_test) ...@@ -250,7 +250,7 @@ TEST_CASE(batch_norm_2d_test)
-2.76707697e+00, 1.47579327e+01, 4.94736385e+00, 2.68847847e+01, -6.49254417e+00, -2.76707697e+00, 1.47579327e+01, 4.94736385e+00, 2.68847847e+01, -6.49254417e+00,
1.94286156e+00, -7.19223642e+00, -3.70413971e+00, -4.04303551e-01, -1.01827660e+01, 1.94286156e+00, -7.19223642e+00, -3.70413971e+00, -4.04303551e-01, -1.01827660e+01,
1.49476433e+00}; 1.49476433e+00};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_3d_test) TEST_CASE(batch_norm_3d_test)
...@@ -292,7 +292,7 @@ TEST_CASE(batch_norm_3d_test) ...@@ -292,7 +292,7 @@ TEST_CASE(batch_norm_3d_test)
6.098, 11.03, 2.81, 2.81, 2.81, 12.125, 3.143, 8.53, 17.52, 4.938, 15.71, 6.098, 11.03, 2.81, 2.81, 2.81, 12.125, 3.143, 8.53, 17.52, 4.938, 15.71,
1.347, 4.938, 1.167, 6.098, 12.67, 12.67, 4.453, 4.453, -0.4768, 12.67}; 1.347, 4.938, 1.167, 6.098, 12.67, 12.67, 4.453, 4.453, -0.4768, 12.67};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()}; std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(celu_verify_test) TEST_CASE(celu_verify_test)
...@@ -309,12 +309,12 @@ TEST_CASE(celu_verify_test) ...@@ -309,12 +309,12 @@ TEST_CASE(celu_verify_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> correct(6); std::vector<float> gold(6);
float alpha = 0.5; float alpha = 0.5;
std::transform(data.begin(), data.end(), correct.begin(), [&](auto x) { std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) {
return std::max(0.0f, x) + std::min(0.0f, alpha * std::expm1(x / alpha)); return std::max(0.0f, x) + std::min(0.0f, alpha * std::expm1(x / alpha));
}); });
EXPECT(migraphx::verify::verify_range(result_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(clip_args_type_mismatch) TEST_CASE(clip_args_type_mismatch)
...@@ -330,7 +330,7 @@ TEST_CASE(clip_args_type_mismatch) ...@@ -330,7 +330,7 @@ TEST_CASE(clip_args_type_mismatch)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.5, 2, 2, 1.9, 2.5, 3, 2.9, 3.2, 3.7}; std::vector<float> gold = {1.5, 2, 2, 1.9, 2.5, 3, 2.9, 3.2, 3.7};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(depthtospace_simple_test) TEST_CASE(depthtospace_simple_test)
...@@ -348,7 +348,7 @@ TEST_CASE(depthtospace_simple_test) ...@@ -348,7 +348,7 @@ TEST_CASE(depthtospace_simple_test)
std::vector<float> gold = {0, 12, 1, 13, 2, 14, 24, 36, 25, 37, 26, 38, 3, 15, 4, 16, std::vector<float> gold = {0, 12, 1, 13, 2, 14, 24, 36, 25, 37, 26, 38, 3, 15, 4, 16,
5, 17, 27, 39, 28, 40, 29, 41, 6, 18, 7, 19, 8, 20, 30, 42, 5, 17, 27, 39, 28, 40, 29, 41, 6, 18, 7, 19, 8, 20, 30, 42,
31, 43, 32, 44, 9, 21, 10, 22, 11, 23, 33, 45, 34, 46, 35, 47}; 31, 43, 32, 44, 9, 21, 10, 22, 11, 23, 33, 45, 34, 46, 35, 47};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(spacetodepth_simple_test) TEST_CASE(spacetodepth_simple_test)
...@@ -366,7 +366,7 @@ TEST_CASE(spacetodepth_simple_test) ...@@ -366,7 +366,7 @@ TEST_CASE(spacetodepth_simple_test)
std::vector<float> gold = {0, 2, 4, 12, 14, 16, 24, 26, 28, 36, 38, 40, 1, 3, 5, 13, std::vector<float> gold = {0, 2, 4, 12, 14, 16, 24, 26, 28, 36, 38, 40, 1, 3, 5, 13,
15, 17, 25, 27, 29, 37, 39, 41, 6, 8, 10, 18, 20, 22, 30, 32, 15, 17, 25, 27, 29, 37, 39, 41, 6, 8, 10, 18, 20, 22, 30, 32,
34, 42, 44, 46, 7, 9, 11, 19, 21, 23, 31, 33, 35, 43, 45, 47}; 34, 42, 44, 46, 7, 9, 11, 19, 21, 23, 31, 33, 35, 43, 45, 47};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(spacetodepth_depthtospace_test) TEST_CASE(spacetodepth_depthtospace_test)
...@@ -374,11 +374,11 @@ TEST_CASE(spacetodepth_depthtospace_test) ...@@ -374,11 +374,11 @@ TEST_CASE(spacetodepth_depthtospace_test)
// space to depth // space to depth
auto p1 = migraphx::parse_onnx("spacetodepth_simple_test.onnx"); auto p1 = migraphx::parse_onnx("spacetodepth_simple_test.onnx");
p1.compile(migraphx::make_target("ref")); p1.compile(migraphx::make_target("ref"));
std::vector<float> data_in(48); std::vector<float> gold_data_in(48);
std::iota(std::begin(data_in), std::end(data_in), 0); std::iota(std::begin(gold_data_in), std::end(gold_data_in), 0);
migraphx::shape s_x_1{migraphx::shape::float_type, {1, 2, 4, 6}}; migraphx::shape s_x_1{migraphx::shape::float_type, {1, 2, 4, 6}};
migraphx::parameter_map pp1; migraphx::parameter_map pp1;
pp1["x"] = migraphx::argument(s_x_1, data_in.data()); pp1["x"] = migraphx::argument(s_x_1, gold_data_in.data());
auto result1 = p1.eval(pp1).back(); auto result1 = p1.eval(pp1).back();
// depth to space // depth to space
auto p2 = migraphx::parse_onnx("depthtospace_simple_test.onnx"); auto p2 = migraphx::parse_onnx("depthtospace_simple_test.onnx");
...@@ -388,7 +388,7 @@ TEST_CASE(spacetodepth_depthtospace_test) ...@@ -388,7 +388,7 @@ TEST_CASE(spacetodepth_depthtospace_test)
auto result2 = p2.eval(pp2).back(); auto result2 = p2.eval(pp2).back();
std::vector<float> result_vector2; std::vector<float> result_vector2;
result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vector2, data_in)); EXPECT(migraphx::verify::verify_rms_range(result_vector2, gold_data_in));
} }
TEST_CASE(eyelike_verify_test) TEST_CASE(eyelike_verify_test)
...@@ -405,8 +405,8 @@ TEST_CASE(eyelike_verify_test) ...@@ -405,8 +405,8 @@ TEST_CASE(eyelike_verify_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> eyelike_mat = {0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.}; std::vector<float> gold_eyelike_mat = {0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.};
EXPECT(migraphx::verify::verify_range(result_vector, eyelike_mat)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_eyelike_mat));
} }
TEST_CASE(eyelike_verify_negk_test) TEST_CASE(eyelike_verify_negk_test)
...@@ -423,8 +423,8 @@ TEST_CASE(eyelike_verify_negk_test) ...@@ -423,8 +423,8 @@ TEST_CASE(eyelike_verify_negk_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> eyelike_mat = {0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.}; std::vector<float> gold_eyelike_mat = {0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(result_vector, eyelike_mat)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_eyelike_mat));
} }
TEST_CASE(gather_elements) TEST_CASE(gather_elements)
...@@ -447,7 +447,7 @@ TEST_CASE(gather_elements) ...@@ -447,7 +447,7 @@ TEST_CASE(gather_elements)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.125, 0.5625, -0.9375, 0.25, 0.5625, 0.9375}; std::vector<float> gold = {-0.125, 0.5625, -0.9375, 0.25, 0.5625, 0.9375};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(gemm_test) TEST_CASE(gemm_test)
...@@ -491,7 +491,7 @@ TEST_CASE(gemm_test) ...@@ -491,7 +491,7 @@ TEST_CASE(gemm_test)
0.8098607, 1.2157929, 1.1010075, 1.0706307, 1.0429881, 1.1771785, 1.2362702, 0.8098607, 1.2157929, 1.1010075, 1.0706307, 1.0429881, 1.1771785, 1.2362702,
0.8239243, 1.1112559, 0.9639262, 1.0813537, 0.8825792, 1.121141, 1.1885703, 0.8239243, 1.1112559, 0.9639262, 1.0813537, 0.8825792, 1.121141, 1.1885703,
1.2227502, 1.4568202, 1.1388762, 1.55058, 1.0958102, 1.4637487, 1.5756242}; 1.2227502, 1.4568202, 1.1388762, 1.55058, 1.0958102, 1.4637487, 1.5756242};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(gemm_half_test) TEST_CASE(gemm_half_test)
...@@ -535,7 +535,7 @@ TEST_CASE(gemm_half_test) ...@@ -535,7 +535,7 @@ TEST_CASE(gemm_half_test)
2.143, 2.062, 1.921, 1.836, 2.203, 1.952, 1.055, 1.225, 1.418, 1.209, 1.155, 2.143, 2.062, 1.921, 1.836, 2.203, 1.952, 1.055, 1.225, 1.418, 1.209, 1.155,
1.42, 1.234, 1.302, 1.593, 1.368, 1.289, 1.327, 1.451, 1.394}; 1.42, 1.234, 1.302, 1.593, 1.368, 1.289, 1.327, 1.451, 1.394};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()}; std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(greaterorequal_test) TEST_CASE(greaterorequal_test)
...@@ -556,7 +556,7 @@ TEST_CASE(greaterorequal_test) ...@@ -556,7 +556,7 @@ TEST_CASE(greaterorequal_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, 0.0}; std::vector<float> gold = {1.0, 1.0, 0.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(hardsigmoid_verify_test) TEST_CASE(hardsigmoid_verify_test)
...@@ -580,7 +580,7 @@ TEST_CASE(hardsigmoid_verify_test) ...@@ -580,7 +580,7 @@ TEST_CASE(hardsigmoid_verify_test)
std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) { std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) {
return std::max(0.0f, std::min(x * alpha + beta, 1.0f)); return std::max(0.0f, std::min(x * alpha + beta, 1.0f));
}); });
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_else_test) TEST_CASE(if_else_test)
...@@ -602,7 +602,7 @@ TEST_CASE(if_else_test) ...@@ -602,7 +602,7 @@ TEST_CASE(if_else_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0866565, -0.371067, 0.017719, 0.0250614, 0.0612539, -0.744683}; std::vector<float> gold = {0.0866565, -0.371067, 0.017719, 0.0250614, 0.0612539, -0.744683};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_else_test_inlined) TEST_CASE(if_else_test_inlined)
...@@ -621,7 +621,7 @@ TEST_CASE(if_else_test_inlined) ...@@ -621,7 +621,7 @@ TEST_CASE(if_else_test_inlined)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0507132, -0.712328, 0.0105797, 0.04569, 0.0185013, -1.16472}; std::vector<float> gold = {0.0507132, -0.712328, 0.0105797, 0.04569, 0.0185013, -1.16472};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_then_test) TEST_CASE(if_then_test)
...@@ -644,7 +644,7 @@ TEST_CASE(if_then_test) ...@@ -644,7 +644,7 @@ TEST_CASE(if_then_test)
// onnx adds ones so result should be just + 1.0 // onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375}; std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_then_test_inlined) TEST_CASE(if_then_test_inlined)
...@@ -663,7 +663,7 @@ TEST_CASE(if_then_test_inlined) ...@@ -663,7 +663,7 @@ TEST_CASE(if_then_test_inlined)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375}; std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_literal_test) TEST_CASE(if_literal_test)
...@@ -688,14 +688,14 @@ TEST_CASE(if_literal_test) ...@@ -688,14 +688,14 @@ TEST_CASE(if_literal_test)
{ {
auto result_vector = run_prog(true); auto result_vector = run_prog(true);
std::vector<float> gold = {1, 2, 3, 4, 5}; std::vector<float> gold = {1, 2, 3, 4, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
// else branch // else branch
{ {
auto result_vector = run_prog(false); auto result_vector = run_prog(false);
std::vector<float> gold = {5, 4, 3, 2, 1}; std::vector<float> gold = {5, 4, 3, 2, 1};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
} }
...@@ -726,7 +726,7 @@ TEST_CASE(if_then_else_multi_output_shapes_inlined_test) ...@@ -726,7 +726,7 @@ TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
std::vector<float> gold = { std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125}; 1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_then_else_multi_output_shapes_test) TEST_CASE(if_then_else_multi_output_shapes_test)
...@@ -757,7 +757,7 @@ TEST_CASE(if_then_else_multi_output_shapes_test) ...@@ -757,7 +757,7 @@ TEST_CASE(if_then_else_multi_output_shapes_test)
std::vector<float> gold = { std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125}; 1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
...@@ -789,14 +789,14 @@ TEST_CASE(if_pl_test) ...@@ -789,14 +789,14 @@ TEST_CASE(if_pl_test)
{ {
auto result_vector = run_prog(true); auto result_vector = run_prog(true);
std::vector<float> gold = {2, 3, 4, 5, 6, 7}; std::vector<float> gold = {2, 3, 4, 5, 6, 7};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
// else branch // else branch
{ {
auto result_vector = run_prog(false); auto result_vector = run_prog(false);
std::vector<float> gold = {1, 2, 3, 4, 5, 6}; std::vector<float> gold = {1, 2, 3, 4, 5, 6};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
} }
...@@ -835,8 +835,8 @@ TEST_CASE(if_tuple_test) ...@@ -835,8 +835,8 @@ TEST_CASE(if_tuple_test)
auto results = run_prog(true); auto results = run_prog(true);
std::vector<float> gold0(4, 2.0f); std::vector<float> gold0(4, 2.0f);
std::vector<float> gold1(12, 4.0f); std::vector<float> gold1(12, 4.0f);
EXPECT(migraphx::verify::verify_range(results.at(0), gold0)); EXPECT(migraphx::verify::verify_rms_range(results.at(0), gold0));
EXPECT(migraphx::verify::verify_range(results.at(1), gold1)); EXPECT(migraphx::verify::verify_rms_range(results.at(1), gold1));
} }
// else branch // else branch
...@@ -844,8 +844,8 @@ TEST_CASE(if_tuple_test) ...@@ -844,8 +844,8 @@ TEST_CASE(if_tuple_test)
auto results = run_prog(false); auto results = run_prog(false);
std::vector<float> gold0(4, 3.0f); std::vector<float> gold0(4, 3.0f);
std::vector<float> gold1(12, 5.0f); std::vector<float> gold1(12, 5.0f);
EXPECT(migraphx::verify::verify_range(results.at(0), gold0)); EXPECT(migraphx::verify::verify_rms_range(results.at(0), gold0));
EXPECT(migraphx::verify::verify_range(results.at(1), gold1)); EXPECT(migraphx::verify::verify_rms_range(results.at(1), gold1));
} }
} }
...@@ -876,7 +876,7 @@ TEST_CASE(instance_norm_test) ...@@ -876,7 +876,7 @@ TEST_CASE(instance_norm_test)
2.54919, 2.54919,
3.32379, 3.32379,
4.09838}; 4.09838};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(instance_norm_dyn_batch_test) TEST_CASE(instance_norm_dyn_batch_test)
...@@ -918,7 +918,7 @@ TEST_CASE(instance_norm_dyn_batch_test) ...@@ -918,7 +918,7 @@ TEST_CASE(instance_norm_dyn_batch_test)
2.54919, 2.54919,
3.32379, 3.32379,
4.09838}; 4.09838};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(instance_norm_3d_test) TEST_CASE(instance_norm_3d_test)
...@@ -947,7 +947,7 @@ TEST_CASE(instance_norm_3d_test) ...@@ -947,7 +947,7 @@ TEST_CASE(instance_norm_3d_test)
3.18218, 3.18218,
4.05505}; 4.05505};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(lessorequal_test) TEST_CASE(lessorequal_test)
...@@ -968,7 +968,7 @@ TEST_CASE(lessorequal_test) ...@@ -968,7 +968,7 @@ TEST_CASE(lessorequal_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 0, 1}; std::vector<float> gold = {1, 0, 1};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(lpnormalization_1norm) TEST_CASE(lpnormalization_1norm)
...@@ -996,7 +996,7 @@ TEST_CASE(lpnormalization_1norm) ...@@ -996,7 +996,7 @@ TEST_CASE(lpnormalization_1norm)
3.f / 7.f, 3.f / 7.f,
0.f, 0.f,
0.f}; 0.f};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(lpnormalization_2norm) TEST_CASE(lpnormalization_2norm)
...@@ -1012,7 +1012,7 @@ TEST_CASE(lpnormalization_2norm) ...@@ -1012,7 +1012,7 @@ TEST_CASE(lpnormalization_2norm)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> correct{0.f, std::vector<float> gold{0.f,
2.f / 3.f, 2.f / 3.f,
-2.f / 3.f, -2.f / 3.f,
1.f / 3.f, 1.f / 3.f,
...@@ -1024,7 +1024,7 @@ TEST_CASE(lpnormalization_2norm) ...@@ -1024,7 +1024,7 @@ TEST_CASE(lpnormalization_2norm)
3.f / 5.f, 3.f / 5.f,
0.f, 0.f,
0.f}; 0.f};
EXPECT(migraphx::verify::verify_range(result_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mean_broadcast_test) TEST_CASE(mean_broadcast_test)
...@@ -1055,7 +1055,7 @@ TEST_CASE(mean_broadcast_test) ...@@ -1055,7 +1055,7 @@ TEST_CASE(mean_broadcast_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(24, 3); std::vector<float> gold(24, 3);
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mean_test) TEST_CASE(mean_test)
...@@ -1082,7 +1082,7 @@ TEST_CASE(mean_test) ...@@ -1082,7 +1082,7 @@ TEST_CASE(mean_test)
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0.0) / num_data; const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0.0) / num_data;
std::vector<double> gold(num_elms, mean); std::vector<double> gold(num_elms, mean);
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mean_integral_test) TEST_CASE(mean_integral_test)
...@@ -1109,7 +1109,7 @@ TEST_CASE(mean_integral_test) ...@@ -1109,7 +1109,7 @@ TEST_CASE(mean_integral_test)
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0) / num_data; const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0) / num_data;
std::vector<int> gold(num_elms, mean); std::vector<int> gold(num_elms, mean);
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mod_test) TEST_CASE(mod_test)
...@@ -1136,7 +1136,7 @@ TEST_CASE(mod_test) ...@@ -1136,7 +1136,7 @@ TEST_CASE(mod_test)
std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2,
5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5}; 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mod_test_different_types) TEST_CASE(mod_test_different_types)
...@@ -1164,7 +1164,7 @@ TEST_CASE(mod_test_different_types) ...@@ -1164,7 +1164,7 @@ TEST_CASE(mod_test_different_types)
std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2,
5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5}; 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mod_test_fmod) TEST_CASE(mod_test_fmod)
...@@ -1193,7 +1193,7 @@ TEST_CASE(mod_test_fmod) ...@@ -1193,7 +1193,7 @@ TEST_CASE(mod_test_fmod)
10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2, 10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2,
7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1}; 7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mod_test_fmod_different_types) TEST_CASE(mod_test_fmod_different_types)
...@@ -1223,7 +1223,7 @@ TEST_CASE(mod_test_fmod_different_types) ...@@ -1223,7 +1223,7 @@ TEST_CASE(mod_test_fmod_different_types)
10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2, 10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2,
7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1}; 7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(nonzero_test) TEST_CASE(nonzero_test)
...@@ -1242,7 +1242,7 @@ TEST_CASE(nonzero_test) ...@@ -1242,7 +1242,7 @@ TEST_CASE(nonzero_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 0, 1, 0, 0, 1, 0, 0}; std::vector<float> gold = {0, 0, 1, 0, 0, 1, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(resize_downsample_f_test) TEST_CASE(resize_downsample_f_test)
...@@ -1263,7 +1263,7 @@ TEST_CASE(resize_downsample_f_test) ...@@ -1263,7 +1263,7 @@ TEST_CASE(resize_downsample_f_test)
std::vector<float> gold = {0.0f, 3.0f}; std::vector<float> gold = {0.0f, 3.0f};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(resize_upsample_linear_ac_test) TEST_CASE(resize_upsample_linear_ac_test)
...@@ -1298,7 +1298,7 @@ TEST_CASE(resize_upsample_linear_ac_test) ...@@ -1298,7 +1298,7 @@ TEST_CASE(resize_upsample_linear_ac_test)
11.0f / 3, 11.0f / 3,
4}; 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(resize_upsample_linear_test) TEST_CASE(resize_upsample_linear_test)
...@@ -1319,7 +1319,7 @@ TEST_CASE(resize_upsample_linear_test) ...@@ -1319,7 +1319,7 @@ TEST_CASE(resize_upsample_linear_test)
std::vector<float> gold = { std::vector<float> gold = {
1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4}; 1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(resize_upsample_pf_test) TEST_CASE(resize_upsample_pf_test)
...@@ -1340,7 +1340,7 @@ TEST_CASE(resize_upsample_pf_test) ...@@ -1340,7 +1340,7 @@ TEST_CASE(resize_upsample_pf_test)
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}; 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(reversesequence_4D_verify_test) TEST_CASE(reversesequence_4D_verify_test)
...@@ -1361,7 +1361,7 @@ TEST_CASE(reversesequence_4D_verify_test) ...@@ -1361,7 +1361,7 @@ TEST_CASE(reversesequence_4D_verify_test)
std::vector<float> gold = { std::vector<float> gold = {
8.0, 9.0, 10.0, 11.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0}; 8.0, 9.0, 10.0, 11.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(reversesequence_batch_verify_test) TEST_CASE(reversesequence_batch_verify_test)
...@@ -1382,7 +1382,7 @@ TEST_CASE(reversesequence_batch_verify_test) ...@@ -1382,7 +1382,7 @@ TEST_CASE(reversesequence_batch_verify_test)
std::vector<float> gold = { std::vector<float> gold = {
0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 10.0, 9.0, 8.0, 11.0, 15.0, 14.0, 13.0, 12.0}; 0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 10.0, 9.0, 8.0, 11.0, 15.0, 14.0, 13.0, 12.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(reversesequence_time_verify_test) TEST_CASE(reversesequence_time_verify_test)
...@@ -1403,7 +1403,7 @@ TEST_CASE(reversesequence_time_verify_test) ...@@ -1403,7 +1403,7 @@ TEST_CASE(reversesequence_time_verify_test)
std::vector<float> gold = { std::vector<float> gold = {
3.0, 6.0, 9.0, 12.0, 2.0, 5.0, 8.0, 13.0, 1.0, 4.0, 10.0, 14.0, 0.0, 7.0, 11.0, 15.0}; 3.0, 6.0, 9.0, 12.0, 2.0, 5.0, 8.0, 13.0, 1.0, 4.0, 10.0, 14.0, 0.0, 7.0, 11.0, 15.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(selu_test) TEST_CASE(selu_test)
...@@ -1423,7 +1423,7 @@ TEST_CASE(selu_test) ...@@ -1423,7 +1423,7 @@ TEST_CASE(selu_test)
std::vector<float> gold = {0.55, 1.05, 0, -0.10912, -0.149251, 6}; std::vector<float> gold = {0.55, 1.05, 0, -0.10912, -0.149251, 6};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(size_verify_test) TEST_CASE(size_verify_test)
...@@ -1457,7 +1457,7 @@ TEST_CASE(slice_test) ...@@ -1457,7 +1457,7 @@ TEST_CASE(slice_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2, 3}; std::vector<float> gold = {2, 3};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(slice_5arg_test) TEST_CASE(slice_5arg_test)
...@@ -1477,7 +1477,7 @@ TEST_CASE(slice_5arg_test) ...@@ -1477,7 +1477,7 @@ TEST_CASE(slice_5arg_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {10, 11, 12, 13, 15, 16, 17, 18}; std::vector<float> gold = {10, 11, 12, 13, 15, 16, 17, 18};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(slice_reverse_test) TEST_CASE(slice_reverse_test)
...@@ -1497,7 +1497,7 @@ TEST_CASE(slice_reverse_test) ...@@ -1497,7 +1497,7 @@ TEST_CASE(slice_reverse_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 13, 12, 11, 19, 18, 17, 16}; std::vector<float> gold = {14, 13, 12, 11, 19, 18, 17, 16};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(slice_step_test) TEST_CASE(slice_step_test)
...@@ -1517,7 +1517,7 @@ TEST_CASE(slice_step_test) ...@@ -1517,7 +1517,7 @@ TEST_CASE(slice_step_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 12}; std::vector<float> gold = {14, 12};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(softplus_test) TEST_CASE(softplus_test)
...@@ -1538,7 +1538,7 @@ TEST_CASE(softplus_test) ...@@ -1538,7 +1538,7 @@ TEST_CASE(softplus_test)
std::transform( std::transform(
data.begin(), data.end(), gold.begin(), [](auto x) { return std::log1p(std::exp(x)); }); data.begin(), data.end(), gold.begin(), [](auto x) { return std::log1p(std::exp(x)); });
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(softsign_test) TEST_CASE(softsign_test)
...@@ -1559,7 +1559,7 @@ TEST_CASE(softsign_test) ...@@ -1559,7 +1559,7 @@ TEST_CASE(softsign_test)
std::transform( std::transform(
data.begin(), data.end(), gold.begin(), [](auto x) { return x / (1.0 + std::abs(x)); }); data.begin(), data.end(), gold.begin(), [](auto x) { return x / (1.0 + std::abs(x)); });
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(upsample_test) TEST_CASE(upsample_test)
...@@ -1578,7 +1578,7 @@ TEST_CASE(upsample_test) ...@@ -1578,7 +1578,7 @@ TEST_CASE(upsample_test)
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}; 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(where_test) TEST_CASE(where_test)
...@@ -1620,7 +1620,7 @@ TEST_CASE(where_test) ...@@ -1620,7 +1620,7 @@ TEST_CASE(where_test)
2.0f, 2.0f,
1.0f, 1.0f,
2.0f}; 2.0f};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p) std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
...@@ -1645,7 +1645,7 @@ TEST_CASE(trilu_test) ...@@ -1645,7 +1645,7 @@ TEST_CASE(trilu_test)
std::vector<float> gold = {1, 2, 3, 4, 0, 6, 7, 8, 0, 0, 11, 12}; std::vector<float> gold = {1, 2, 3, 4, 0, 6, 7, 8, 0, 0, 11, 12};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(trilu_batch_diff_k_test) TEST_CASE(trilu_batch_diff_k_test)
...@@ -1656,7 +1656,7 @@ TEST_CASE(trilu_batch_diff_k_test) ...@@ -1656,7 +1656,7 @@ TEST_CASE(trilu_batch_diff_k_test)
std::vector<float> gold = {0, 0, 3, 0, 0, 0, 0, 0, 9, 0, 0, 0}; std::vector<float> gold = {0, 0, 3, 0, 0, 0, 0, 0, 9, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(trilu_lower_test) TEST_CASE(trilu_lower_test)
...@@ -1667,7 +1667,7 @@ TEST_CASE(trilu_lower_test) ...@@ -1667,7 +1667,7 @@ TEST_CASE(trilu_lower_test)
std::vector<float> gold = {0, 0, 0, 0, 5, 0, 0, 0, 9, 10, 0, 0}; std::vector<float> gold = {0, 0, 0, 0, 5, 0, 0, 0, 9, 10, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(trilu_out_k_test) TEST_CASE(trilu_out_k_test)
...@@ -1678,7 +1678,7 @@ TEST_CASE(trilu_out_k_test) ...@@ -1678,7 +1678,7 @@ TEST_CASE(trilu_out_k_test)
std::vector<float> gold(12, 0); std::vector<float> gold(12, 0);
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(trilu_row_one_test) TEST_CASE(trilu_row_one_test)
...@@ -1689,7 +1689,7 @@ TEST_CASE(trilu_row_one_test) ...@@ -1689,7 +1689,7 @@ TEST_CASE(trilu_row_one_test)
std::vector<float> gold = {0, 2, 3, 4}; std::vector<float> gold = {0, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1013,7 +1013,7 @@ TEST_CASE(target_copy) ...@@ -1013,7 +1013,7 @@ TEST_CASE(target_copy)
std::vector<float> orig_result; std::vector<float> orig_result;
run_prog(p, ref_t, m, orig_result); run_prog(p, ref_t, m, orig_result);
EXPECT(migraphx::verify::verify_range(ref_result, orig_result)); EXPECT(migraphx::verify::verify_rms_range(ref_result, orig_result));
} }
} }
...@@ -1077,7 +1077,10 @@ TEST_CASE(int8_quantization_dot) ...@@ -1077,7 +1077,10 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000)); EXPECT(migraphx::verify::verify_range_with_tolerance(
quant_result,
migraphx::verify::expected{no_quant_result},
migraphx::verify::tolerance{0.003}));
} }
} }
...@@ -1122,7 +1125,7 @@ TEST_CASE(int8_quantization_conv) ...@@ -1122,7 +1125,7 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, no_quant_result); run_prog(p, ref_t, no_quant_result);
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result)); EXPECT(migraphx::verify::verify_rms_range(quant_result, no_quant_result));
} }
} }
...@@ -1274,7 +1277,7 @@ TEST_CASE(test_op_capture) ...@@ -1274,7 +1277,7 @@ TEST_CASE(test_op_capture)
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); }); cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(vec, cap_vec)); EXPECT(migraphx::verify::verify_rms_range(vec, cap_vec));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -42,7 +42,7 @@ TEST_CASE(abs_test) ...@@ -42,7 +42,7 @@ TEST_CASE(abs_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4}; std::vector<float> gold{1, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(abs_dyn_test) TEST_CASE(abs_dyn_test)
...@@ -62,5 +62,5 @@ TEST_CASE(abs_dyn_test) ...@@ -62,5 +62,5 @@ TEST_CASE(abs_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4}; std::vector<float> gold{1, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(acos_test) ...@@ -45,7 +45,7 @@ TEST_CASE(acos_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(acos_dyn_test) TEST_CASE(acos_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(acos_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(acos_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(acosh_test) ...@@ -45,7 +45,7 @@ TEST_CASE(acosh_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(acosh_dyn_test) TEST_CASE(acosh_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(acosh_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(acosh_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -51,7 +51,7 @@ TEST_CASE(add_broadcast_test) ...@@ -51,7 +51,7 @@ TEST_CASE(add_broadcast_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(add_multibroadcast_test) TEST_CASE(add_multibroadcast_test)
...@@ -75,7 +75,7 @@ TEST_CASE(add_multibroadcast_test) ...@@ -75,7 +75,7 @@ TEST_CASE(add_multibroadcast_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(add_test) TEST_CASE(add_test)
...@@ -91,7 +91,7 @@ TEST_CASE(add_test) ...@@ -91,7 +91,7 @@ TEST_CASE(add_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(add_dyn_test) TEST_CASE(add_dyn_test)
...@@ -115,7 +115,7 @@ TEST_CASE(add_dyn_test) ...@@ -115,7 +115,7 @@ TEST_CASE(add_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(fp16_test) TEST_CASE(fp16_test)
...@@ -134,7 +134,7 @@ TEST_CASE(fp16_test) ...@@ -134,7 +134,7 @@ TEST_CASE(fp16_test)
std::vector<migraphx::half> results_vector(1); std::vector<migraphx::half> results_vector(1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<migraphx::half> gold{c}; std::vector<migraphx::half> gold{c};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(fp32_fp16_test) TEST_CASE(fp32_fp16_test)
...@@ -159,7 +159,7 @@ TEST_CASE(fp32_fp16_test) ...@@ -159,7 +159,7 @@ TEST_CASE(fp32_fp16_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> res; std::vector<float> res;
result.visit([&](auto output) { res.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res, gold_res)); EXPECT(migraphx::verify::verify_rms_range(res, gold_res));
}; };
test_case({"all"}); test_case({"all"});
......
...@@ -47,7 +47,7 @@ TEST_CASE(argmax_test_0) ...@@ -47,7 +47,7 @@ TEST_CASE(argmax_test_0)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_1) TEST_CASE(argmax_test_1)
...@@ -66,7 +66,7 @@ TEST_CASE(argmax_test_1) ...@@ -66,7 +66,7 @@ TEST_CASE(argmax_test_1)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_2) TEST_CASE(argmax_test_2)
...@@ -85,7 +85,7 @@ TEST_CASE(argmax_test_2) ...@@ -85,7 +85,7 @@ TEST_CASE(argmax_test_2)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_neg_2) TEST_CASE(argmax_test_neg_2)
...@@ -104,7 +104,7 @@ TEST_CASE(argmax_test_neg_2) ...@@ -104,7 +104,7 @@ TEST_CASE(argmax_test_neg_2)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_dyn_test) TEST_CASE(argmax_dyn_test)
...@@ -126,7 +126,7 @@ TEST_CASE(argmax_dyn_test) ...@@ -126,7 +126,7 @@ TEST_CASE(argmax_dyn_test)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1}; std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_nonstd_shape) TEST_CASE(argmax_test_nonstd_shape)
...@@ -145,5 +145,5 @@ TEST_CASE(argmax_test_nonstd_shape) ...@@ -145,5 +145,5 @@ TEST_CASE(argmax_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold_vec));
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment