Unverified Commit 69d8d789 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add options to set tolerances inside MIGraphX driver (#2213)

MIGraphX verification by default uses normalized RMS error as the basis for the verification.  This change adds some logic to allow migraphx to do "np.allclose" type of elementwise verification using atol and rtol.

Commit also includes changes to consistently pass "gold" or "expected" results as the second argument for "verify_range()" calls.  Default RMS tolerance inside driver is set to 0.001 which IMO is high for FP32 compared to what we had earlier. Need better defaults
parent e12032fb
......@@ -131,7 +131,7 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, sol));
EXPECT(migraphx::verify::verify_rms_range(results_vector, sol));
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU.
By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target.
......
......@@ -50,9 +50,17 @@ Runs reference and CPU or GPU implementations and checks outputs for consistency
.. include:: ./driver/compile.rst
.. option:: --tolerance [double]
.. option:: --rms-tol [double]
Tolerance for errors (Default: 80)
Tolerance for RMS error (Default: 0.001)
.. option:: --atol [double]
Tolerance for elementwise absolute difference (Default: 0.001)
.. option:: --rtol [double]
Tolerance for elementwise relative difference (Default: 0.001)
.. option:: -i, --per-instruction
......
......@@ -55,7 +55,9 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 |
| --tolerance | Tolerance for errors |
| --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
| --per-instruction \| -i | Verify each instruction |
| --reduce \| -r | Reduce program and verify |
| --iterations \| -n | Number of iterations to run for perf report |
......
......@@ -536,13 +536,19 @@ struct params : command<params>
struct verify : command<verify>
{
compiler c;
double tolerance = 80;
migraphx::verify::tolerance tols;
bool per_instruction = false;
bool reduce = false;
void parse(argument_parser& ap)
{
c.parse(ap);
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
ap(tols.rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error (Default: 0.001)"));
ap(tols.atol,
{"--atol"},
ap.help("Tolerance for the elementwise absolute difference (Default: 0.001)"));
ap(tols.rtol,
{"--rtol"},
ap.help("Tolerance for the elementwise relative difference (Default: 0.001)"));
ap(per_instruction,
{"-i", "--per-instruction"},
ap.help("Verify each instruction"),
......@@ -567,15 +573,15 @@ struct verify : command<verify>
if(per_instruction)
{
verify_instructions(p, t, c.co, quantize, tolerance);
verify_instructions(p, t, c.co, quantize, tols);
}
else if(reduce)
{
verify_reduced_program(p, t, c.co, quantize, m, tolerance);
verify_reduced_program(p, t, c.co, quantize, m, tols);
}
else
{
verify_program(c.l.file, p, t, c.co, quantize, m, tolerance);
verify_program(c.l.file, p, t, c.co, quantize, m, tols);
}
}
};
......
......@@ -77,24 +77,24 @@ void verify_program(const std::string& name,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
verify::tolerance tols)
{
auto x = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs);
auto ref_outs = run_ref(p, inputs);
auto target_outs = run_target(p, t, options, quantize, inputs);
std::size_t output_num = x.size();
std::size_t output_num = ref_outs.size();
for(std::size_t i = 0; i < output_num; ++i)
{
if(x[i].get_shape().type() != y[i].get_shape().type() or
x[i].get_shape().lens() != y[i].get_shape().lens())
if(ref_outs[i].get_shape().type() != target_outs[i].get_shape().type() or
ref_outs[i].get_shape().lens() != target_outs[i].get_shape().lens())
{
std::cout << "FAILED: " << name << std::endl;
std::cout << "Shape mismatch {" << x[i].get_shape() << "} != {" << y[i].get_shape()
<< "}" << std::endl;
std::cout << "Shape mismatch {" << ref_outs[i].get_shape() << "} != {"
<< target_outs[i].get_shape() << "}" << std::endl;
}
else
{
verify_args(name, x[i], y[i], tolerance);
verify_args(name, target_outs[i], verify::expected{ref_outs[i]}, tols);
}
}
}
......@@ -103,7 +103,7 @@ void verify_instructions(const program& prog,
const target& t,
compile_options options,
precision quantize,
double tolerance)
verify::tolerance tols)
{
const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog))
......@@ -134,8 +134,7 @@ void verify_instructions(const program& prog,
{
std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl;
verify_program(
ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance);
verify_program(ins.name(), p, t, options, quantize, create_param_map(p, false), tols);
}
catch(...)
{
......@@ -151,7 +150,7 @@ void verify_reduced(program p,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
verify::tolerance tols)
{
auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n);
......@@ -160,7 +159,7 @@ void verify_reduced(program p,
std::cout << p << std::endl;
try
{
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
verify_program(std::to_string(n), p, t, options, quantize, inputs, tols);
}
catch(const std::exception& e)
{
......@@ -174,7 +173,7 @@ void verify_reduced_program(const program& p,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
verify::tolerance tols)
{
const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end());
......@@ -187,7 +186,7 @@ void verify_reduced_program(const program& p,
std::cout << "Skip: " << i << std::endl;
continue;
}
verify_reduced(p, i, t, options, quantize, inputs, tolerance);
verify_reduced(p, i, t, options, quantize, inputs, tols);
}
}
......
......@@ -26,6 +26,7 @@
#include "precision.hpp"
#include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace migraphx {
namespace driver {
......@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
compile_options options = compile_options{},
precision quantize = precision::fp32,
const parameter_map& inputs = {},
double tolerance = 100);
verify::tolerance tols = verify::tolerance{});
void verify_instructions(const program& prog,
const target& t,
compile_options options = compile_options{},
precision quantize = precision::fp32,
double tolerance = 80);
verify::tolerance tols = verify::tolerance{});
void verify_reduced_program(const program& p,
const target& t,
compile_options options = compile_options{},
precision quantize = precision::fp32,
const parameter_map& inputs = {},
double tolerance = 80);
verify::tolerance tols = verify::tolerance{});
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
......
......@@ -29,10 +29,13 @@
#include <functional>
#include <iostream>
#include <numeric>
#include <assert.h>
#include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE)
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace verify {
......@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2)
return std::numeric_limits<range_value<R1>>::max();
}
template <class R>
double get_rms_tol(const R&, std::size_t tolerance = 80)
{
double threshold = std::numeric_limits<range_value<R>>::epsilon() * tolerance;
return threshold;
}
/*
C++ doesn't support named arguments, this is just wrapper that helps distinguish between actual
results v/s expected results arguments.
*/
template <class T>
struct expected
{
expected() = default;
explicit expected(const T& input) : x(&input) {}
const T& data() const
{
assert(x != nullptr);
return *x;
}
private:
const T* x = nullptr;
};
// deduction guide for templated expected class
template <class T>
expected(const T&) -> expected<T>;
struct tolerance
{
double rms_tol = 0.001;
double atol = 0.001;
double rtol = 0.001;
};
/*
MIGraphX implementation of numpy's np.allclose() which checks if elementwise absolute diff is within
tolerance using this formula: abs(a - b) < atol + rtol(abs(b))
*/
template <class R1, class R2>
bool allclose(const R1& r1, const R2& r2, tolerance tols)
{
std::size_t n = range_distance(r1);
if(n == range_distance(r2))
{
auto idx = mismatch_idx(r1, r2, [&](auto x, auto y) {
return abs_diff(double(x), double(y)) < tols.atol + tols.rtol * std::abs(double(y));
});
return idx >= range_distance(r1);
}
return false;
}
template <class R1, class R2>
bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr)
bool verify_rms_range(const R1& r1,
const R2& r2,
std::size_t tolerance = 80,
double* out_rms_error = nullptr)
{
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
double threshold = get_rms_tol(r1, tolerance);
auto error = rms_range(r1, r2);
if(out_error != nullptr)
*out_error = error;
if(out_rms_error != nullptr)
*out_rms_error = error;
return error <= threshold;
}
template <class R1, class R2>
bool verify_range_with_tolerance(const R1& r1,
const expected<R2>& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
auto rms_error = rms_range(r1, r2.data());
// disable ewise_verify by default for now, it requires lot of tests to be fixed
bool ewise_verify = true;
if(enabled(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE{}))
{
ewise_verify = allclose(r1, r2.data(), tols);
}
if(out_rms_error != nullptr)
*out_rms_error = rms_error;
return rms_error <= tols.rms_tol and ewise_verify;
}
// expected argument should be passed as second, but if it is passed as the first by mistake then
// flip the order
template <class R1, class R2>
bool verify_range_with_tolerance(const expected<R1>& r1,
const R2& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
return verify_rms_range(r2, r1, tols, out_rms_error);
}
} // namespace verify
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -31,11 +31,15 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT
bool verify_args(const std::string& name,
const argument& ref_arg,
const argument& target_arg,
double tolerance = 80);
MIGRAPHX_EXPORT bool verify_args(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
verify::tolerance);
MIGRAPHX_EXPORT bool verify_args_with_tolerance(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -28,19 +28,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool verify_args(const std::string& name,
const argument& ref_arg,
const argument& target_arg,
double tolerance)
const verify::expected<argument>& ref_arg,
verify::tolerance tols)
{
bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error;
passed = verify::verify_range(ref, target, tolerance, &error);
visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
double rms_error;
passed =
verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
if(not passed)
{
// TODO: Check for nans
std::cout << "FAILED: " << name << std::endl;
std::cout << "error: " << error << std::endl;
std::cout << "RMS Error: " << rms_error << std::endl;
if(ref.size() < 32)
std::cout << "ref:" << ref << std::endl;
if(target.size() < 32)
......@@ -93,5 +94,16 @@ bool verify_args(const std::string& name,
return passed;
}
bool verify_args_with_tolerance(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance)
{
double rms_tol = 0.001;
target_arg.visit([&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); });
verify::tolerance tols{rms_tol};
return verify_args(name, target_arg, ref_arg, tols);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -53,7 +53,6 @@ TEST_CASE(host_same_buffer_copy)
migraphx::parameter_map pp;
std::vector<float> a_vec(ss.elements(), -1);
std::vector<float> b_vec(ss.elements(), 2);
std::vector<float> c_vec(ss.elements(), 0);
pp["a"] = migraphx::argument(ss, a_vec.data());
pp["b"] = migraphx::argument(ss, b_vec.data());
std::vector<float> gpu_result;
......@@ -64,7 +63,8 @@ TEST_CASE(host_same_buffer_copy)
auto result = p.eval(pp).back();
std::vector<float> results_vector(ss.elements(), -1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(c_vec, results_vector));
std::vector<float> gold_vec(ss.elements(), 0);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_vec));
}
TEST_CASE(arguments_lifetime)
......
......@@ -133,7 +133,8 @@ bool verify_mlir(const migraphx::module& mmlir)
auto inputs = generate_params(ref);
auto mlir = create_program_from_mlir(mmlir);
return migraphx::verify_args("mlir", run_ref(ref, inputs), run_gpu(mlir, inputs));
return migraphx::verify_args_with_tolerance(
"mlir", run_gpu(mlir, inputs), migraphx::verify::expected{run_ref(ref, inputs)});
}
TEST_CASE(conv)
......
......@@ -40,7 +40,6 @@
TEST_CASE(gpu_target_copy)
{
migraphx::target gpu_t = migraphx::make_target("gpu");
migraphx::target ref_t = migraphx::make_target("ref");
migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}};
auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L);
......@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy)
std::vector<int8_t> val_final;
ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
EXPECT(migraphx::verify::verify_range(val_orig, val_final));
EXPECT(migraphx::verify::verify_rms_range(val_orig, val_final));
}
TEST_CASE(int8_quantization)
......@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much
// earlier stage.
if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 1e5));
EXPECT(migraphx::verify::verify_range_with_tolerance(
gpu_result,
migraphx::verify::expected{ref_result},
migraphx::verify::tolerance{0.01}));
else
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
}
}
......
......@@ -47,7 +47,7 @@ TEST_CASE(averagepool_notset_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {12};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(averagepool_nt_cip_test)
......@@ -65,7 +65,7 @@ TEST_CASE(averagepool_nt_cip_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {8.33333};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(batch_norm_flat_test)
......@@ -76,15 +76,15 @@ TEST_CASE(batch_norm_flat_test)
migraphx::shape x_shape{migraphx::shape::float_type, {10}};
migraphx::shape c_shape(migraphx::shape::float_type, {1});
std::vector<float> x_data = {1.6524342,
-0.51048076,
0.32543048,
2.4410043,
2.0833702,
0.44981122,
1.0044622,
-0.24006313,
-0.43065986,
0.07626268};
-0.51048076,
0.32543048,
2.4410043,
2.0833702,
0.44981122,
1.0044622,
-0.24006313,
-0.43065986,
0.07626268};
std::vector<float> scale_data = {-0.02927135};
std::vector<float> bias_data = {0.42347777};
std::vector<float> mean_data = {-0.00449735};
......@@ -111,7 +111,7 @@ TEST_CASE(batch_norm_flat_test)
0.43305403,
0.4408022,
0.42019472};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(batch_norm_rank_2_test)
......@@ -148,7 +148,7 @@ TEST_CASE(batch_norm_rank_2_test)
9.89948504,
9.89948504,
12.72790933};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(batch_norm_1d_test)
......@@ -184,7 +184,7 @@ TEST_CASE(batch_norm_1d_test)
0.4927, 0.771, -1.956, -2.123, -0.664, -0.583, -0.7207, -0.5127};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(batch_norm_2d_test)
......@@ -250,7 +250,7 @@ TEST_CASE(batch_norm_2d_test)
-2.76707697e+00, 1.47579327e+01, 4.94736385e+00, 2.68847847e+01, -6.49254417e+00,
1.94286156e+00, -7.19223642e+00, -3.70413971e+00, -4.04303551e-01, -1.01827660e+01,
1.49476433e+00};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(batch_norm_3d_test)
......@@ -292,7 +292,7 @@ TEST_CASE(batch_norm_3d_test)
6.098, 11.03, 2.81, 2.81, 2.81, 12.125, 3.143, 8.53, 17.52, 4.938, 15.71,
1.347, 4.938, 1.167, 6.098, 12.67, 12.67, 4.453, 4.453, -0.4768, 12.67};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(celu_verify_test)
......@@ -309,12 +309,12 @@ TEST_CASE(celu_verify_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> correct(6);
std::vector<float> gold(6);
float alpha = 0.5;
std::transform(data.begin(), data.end(), correct.begin(), [&](auto x) {
std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) {
return std::max(0.0f, x) + std::min(0.0f, alpha * std::expm1(x / alpha));
});
EXPECT(migraphx::verify::verify_range(result_vector, correct));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(clip_args_type_mismatch)
......@@ -330,7 +330,7 @@ TEST_CASE(clip_args_type_mismatch)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.5, 2, 2, 1.9, 2.5, 3, 2.9, 3.2, 3.7};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(depthtospace_simple_test)
......@@ -348,7 +348,7 @@ TEST_CASE(depthtospace_simple_test)
std::vector<float> gold = {0, 12, 1, 13, 2, 14, 24, 36, 25, 37, 26, 38, 3, 15, 4, 16,
5, 17, 27, 39, 28, 40, 29, 41, 6, 18, 7, 19, 8, 20, 30, 42,
31, 43, 32, 44, 9, 21, 10, 22, 11, 23, 33, 45, 34, 46, 35, 47};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(spacetodepth_simple_test)
......@@ -366,7 +366,7 @@ TEST_CASE(spacetodepth_simple_test)
std::vector<float> gold = {0, 2, 4, 12, 14, 16, 24, 26, 28, 36, 38, 40, 1, 3, 5, 13,
15, 17, 25, 27, 29, 37, 39, 41, 6, 8, 10, 18, 20, 22, 30, 32,
34, 42, 44, 46, 7, 9, 11, 19, 21, 23, 31, 33, 35, 43, 45, 47};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(spacetodepth_depthtospace_test)
......@@ -374,11 +374,11 @@ TEST_CASE(spacetodepth_depthtospace_test)
// space to depth
auto p1 = migraphx::parse_onnx("spacetodepth_simple_test.onnx");
p1.compile(migraphx::make_target("ref"));
std::vector<float> data_in(48);
std::iota(std::begin(data_in), std::end(data_in), 0);
std::vector<float> gold_data_in(48);
std::iota(std::begin(gold_data_in), std::end(gold_data_in), 0);
migraphx::shape s_x_1{migraphx::shape::float_type, {1, 2, 4, 6}};
migraphx::parameter_map pp1;
pp1["x"] = migraphx::argument(s_x_1, data_in.data());
pp1["x"] = migraphx::argument(s_x_1, gold_data_in.data());
auto result1 = p1.eval(pp1).back();
// depth to space
auto p2 = migraphx::parse_onnx("depthtospace_simple_test.onnx");
......@@ -388,7 +388,7 @@ TEST_CASE(spacetodepth_depthtospace_test)
auto result2 = p2.eval(pp2).back();
std::vector<float> result_vector2;
result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vector2, data_in));
EXPECT(migraphx::verify::verify_rms_range(result_vector2, gold_data_in));
}
TEST_CASE(eyelike_verify_test)
......@@ -405,8 +405,8 @@ TEST_CASE(eyelike_verify_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> eyelike_mat = {0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.};
EXPECT(migraphx::verify::verify_range(result_vector, eyelike_mat));
std::vector<float> gold_eyelike_mat = {0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_eyelike_mat));
}
TEST_CASE(eyelike_verify_negk_test)
......@@ -423,8 +423,8 @@ TEST_CASE(eyelike_verify_negk_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> eyelike_mat = {0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(result_vector, eyelike_mat));
std::vector<float> gold_eyelike_mat = {0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_eyelike_mat));
}
TEST_CASE(gather_elements)
......@@ -447,7 +447,7 @@ TEST_CASE(gather_elements)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.125, 0.5625, -0.9375, 0.25, 0.5625, 0.9375};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(gemm_test)
......@@ -491,7 +491,7 @@ TEST_CASE(gemm_test)
0.8098607, 1.2157929, 1.1010075, 1.0706307, 1.0429881, 1.1771785, 1.2362702,
0.8239243, 1.1112559, 0.9639262, 1.0813537, 0.8825792, 1.121141, 1.1885703,
1.2227502, 1.4568202, 1.1388762, 1.55058, 1.0958102, 1.4637487, 1.5756242};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(gemm_half_test)
......@@ -535,7 +535,7 @@ TEST_CASE(gemm_half_test)
2.143, 2.062, 1.921, 1.836, 2.203, 1.952, 1.055, 1.225, 1.418, 1.209, 1.155,
1.42, 1.234, 1.302, 1.593, 1.368, 1.289, 1.327, 1.451, 1.394};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(greaterorequal_test)
......@@ -556,7 +556,7 @@ TEST_CASE(greaterorequal_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, 0.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(hardsigmoid_verify_test)
......@@ -580,7 +580,7 @@ TEST_CASE(hardsigmoid_verify_test)
std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) {
return std::max(0.0f, std::min(x * alpha + beta, 1.0f));
});
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(if_else_test)
......@@ -602,7 +602,7 @@ TEST_CASE(if_else_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0866565, -0.371067, 0.017719, 0.0250614, 0.0612539, -0.744683};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(if_else_test_inlined)
......@@ -621,7 +621,7 @@ TEST_CASE(if_else_test_inlined)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0507132, -0.712328, 0.0105797, 0.04569, 0.0185013, -1.16472};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(if_then_test)
......@@ -644,7 +644,7 @@ TEST_CASE(if_then_test)
// onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(if_then_test_inlined)
......@@ -663,7 +663,7 @@ TEST_CASE(if_then_test_inlined)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(if_literal_test)
......@@ -688,14 +688,14 @@ TEST_CASE(if_literal_test)
{
auto result_vector = run_prog(true);
std::vector<float> gold = {1, 2, 3, 4, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// else branch
{
auto result_vector = run_prog(false);
std::vector<float> gold = {5, 4, 3, 2, 1};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
}
......@@ -726,7 +726,7 @@ TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(if_then_else_multi_output_shapes_test)
......@@ -757,7 +757,7 @@ TEST_CASE(if_then_else_multi_output_shapes_test)
std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(if_pl_test)
......@@ -789,14 +789,14 @@ TEST_CASE(if_pl_test)
{
auto result_vector = run_prog(true);
std::vector<float> gold = {2, 3, 4, 5, 6, 7};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// else branch
{
auto result_vector = run_prog(false);
std::vector<float> gold = {1, 2, 3, 4, 5, 6};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
}
......@@ -835,8 +835,8 @@ TEST_CASE(if_tuple_test)
auto results = run_prog(true);
std::vector<float> gold0(4, 2.0f);
std::vector<float> gold1(12, 4.0f);
EXPECT(migraphx::verify::verify_range(results.at(0), gold0));
EXPECT(migraphx::verify::verify_range(results.at(1), gold1));
EXPECT(migraphx::verify::verify_rms_range(results.at(0), gold0));
EXPECT(migraphx::verify::verify_rms_range(results.at(1), gold1));
}
// else branch
......@@ -844,8 +844,8 @@ TEST_CASE(if_tuple_test)
auto results = run_prog(false);
std::vector<float> gold0(4, 3.0f);
std::vector<float> gold1(12, 5.0f);
EXPECT(migraphx::verify::verify_range(results.at(0), gold0));
EXPECT(migraphx::verify::verify_range(results.at(1), gold1));
EXPECT(migraphx::verify::verify_rms_range(results.at(0), gold0));
EXPECT(migraphx::verify::verify_rms_range(results.at(1), gold1));
}
}
......@@ -876,7 +876,7 @@ TEST_CASE(instance_norm_test)
2.54919,
3.32379,
4.09838};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(instance_norm_dyn_batch_test)
......@@ -918,7 +918,7 @@ TEST_CASE(instance_norm_dyn_batch_test)
2.54919,
3.32379,
4.09838};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(instance_norm_3d_test)
......@@ -947,7 +947,7 @@ TEST_CASE(instance_norm_3d_test)
3.18218,
4.05505};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(lessorequal_test)
......@@ -968,7 +968,7 @@ TEST_CASE(lessorequal_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 0, 1};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(lpnormalization_1norm)
......@@ -996,7 +996,7 @@ TEST_CASE(lpnormalization_1norm)
3.f / 7.f,
0.f,
0.f};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(lpnormalization_2norm)
......@@ -1012,19 +1012,19 @@ TEST_CASE(lpnormalization_2norm)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> correct{0.f,
2.f / 3.f,
-2.f / 3.f,
1.f / 3.f,
1.f / 6.f,
-5.f / 6.f,
3.f / 6.f,
-1.f / 6.f,
-4.f / 5.f,
3.f / 5.f,
0.f,
0.f};
EXPECT(migraphx::verify::verify_range(result_vector, correct));
std::vector<float> gold{0.f,
2.f / 3.f,
-2.f / 3.f,
1.f / 3.f,
1.f / 6.f,
-5.f / 6.f,
3.f / 6.f,
-1.f / 6.f,
-4.f / 5.f,
3.f / 5.f,
0.f,
0.f};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(mean_broadcast_test)
......@@ -1055,7 +1055,7 @@ TEST_CASE(mean_broadcast_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(24, 3);
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(mean_test)
......@@ -1082,7 +1082,7 @@ TEST_CASE(mean_test)
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0.0) / num_data;
std::vector<double> gold(num_elms, mean);
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(mean_integral_test)
......@@ -1109,7 +1109,7 @@ TEST_CASE(mean_integral_test)
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0) / num_data;
std::vector<int> gold(num_elms, mean);
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(mod_test)
......@@ -1136,7 +1136,7 @@ TEST_CASE(mod_test)
std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2,
5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(mod_test_different_types)
......@@ -1164,7 +1164,7 @@ TEST_CASE(mod_test_different_types)
std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2,
5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(mod_test_fmod)
......@@ -1193,7 +1193,7 @@ TEST_CASE(mod_test_fmod)
10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2,
7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(mod_test_fmod_different_types)
......@@ -1223,7 +1223,7 @@ TEST_CASE(mod_test_fmod_different_types)
10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2,
7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(nonzero_test)
......@@ -1242,7 +1242,7 @@ TEST_CASE(nonzero_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 0, 1, 0, 0, 1, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_downsample_f_test)
......@@ -1263,7 +1263,7 @@ TEST_CASE(resize_downsample_f_test)
std::vector<float> gold = {0.0f, 3.0f};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_upsample_linear_ac_test)
......@@ -1298,7 +1298,7 @@ TEST_CASE(resize_upsample_linear_ac_test)
11.0f / 3,
4};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_upsample_linear_test)
......@@ -1319,7 +1319,7 @@ TEST_CASE(resize_upsample_linear_test)
std::vector<float> gold = {
1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_upsample_pf_test)
......@@ -1340,7 +1340,7 @@ TEST_CASE(resize_upsample_pf_test)
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(reversesequence_4D_verify_test)
......@@ -1361,7 +1361,7 @@ TEST_CASE(reversesequence_4D_verify_test)
std::vector<float> gold = {
8.0, 9.0, 10.0, 11.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(reversesequence_batch_verify_test)
......@@ -1382,7 +1382,7 @@ TEST_CASE(reversesequence_batch_verify_test)
std::vector<float> gold = {
0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 10.0, 9.0, 8.0, 11.0, 15.0, 14.0, 13.0, 12.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(reversesequence_time_verify_test)
......@@ -1403,7 +1403,7 @@ TEST_CASE(reversesequence_time_verify_test)
std::vector<float> gold = {
3.0, 6.0, 9.0, 12.0, 2.0, 5.0, 8.0, 13.0, 1.0, 4.0, 10.0, 14.0, 0.0, 7.0, 11.0, 15.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(selu_test)
......@@ -1423,7 +1423,7 @@ TEST_CASE(selu_test)
std::vector<float> gold = {0.55, 1.05, 0, -0.10912, -0.149251, 6};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(size_verify_test)
......@@ -1457,7 +1457,7 @@ TEST_CASE(slice_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2, 3};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(slice_5arg_test)
......@@ -1477,7 +1477,7 @@ TEST_CASE(slice_5arg_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {10, 11, 12, 13, 15, 16, 17, 18};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(slice_reverse_test)
......@@ -1497,7 +1497,7 @@ TEST_CASE(slice_reverse_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 13, 12, 11, 19, 18, 17, 16};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(slice_step_test)
......@@ -1517,7 +1517,7 @@ TEST_CASE(slice_step_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 12};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(softplus_test)
......@@ -1538,7 +1538,7 @@ TEST_CASE(softplus_test)
std::transform(
data.begin(), data.end(), gold.begin(), [](auto x) { return std::log1p(std::exp(x)); });
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(softsign_test)
......@@ -1559,7 +1559,7 @@ TEST_CASE(softsign_test)
std::transform(
data.begin(), data.end(), gold.begin(), [](auto x) { return x / (1.0 + std::abs(x)); });
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(upsample_test)
......@@ -1578,7 +1578,7 @@ TEST_CASE(upsample_test)
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(where_test)
......@@ -1620,7 +1620,7 @@ TEST_CASE(where_test)
2.0f,
1.0f,
2.0f};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
......@@ -1645,7 +1645,7 @@ TEST_CASE(trilu_test)
std::vector<float> gold = {1, 2, 3, 4, 0, 6, 7, 8, 0, 0, 11, 12};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(trilu_batch_diff_k_test)
......@@ -1656,7 +1656,7 @@ TEST_CASE(trilu_batch_diff_k_test)
std::vector<float> gold = {0, 0, 3, 0, 0, 0, 0, 0, 9, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(trilu_lower_test)
......@@ -1667,7 +1667,7 @@ TEST_CASE(trilu_lower_test)
std::vector<float> gold = {0, 0, 0, 0, 5, 0, 0, 0, 9, 10, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(trilu_out_k_test)
......@@ -1678,7 +1678,7 @@ TEST_CASE(trilu_out_k_test)
std::vector<float> gold(12, 0);
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(trilu_row_one_test)
......@@ -1689,7 +1689,7 @@ TEST_CASE(trilu_row_one_test)
std::vector<float> gold = {0, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -83,7 +83,7 @@ TEST_CASE(param_add)
auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto fs = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hs);
if(add_return)
{
......@@ -1013,7 +1013,7 @@ TEST_CASE(target_copy)
std::vector<float> orig_result;
run_prog(p, ref_t, m, orig_result);
EXPECT(migraphx::verify::verify_range(ref_result, orig_result));
EXPECT(migraphx::verify::verify_rms_range(ref_result, orig_result));
}
}
......@@ -1077,7 +1077,10 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000));
EXPECT(migraphx::verify::verify_range_with_tolerance(
quant_result,
migraphx::verify::expected{no_quant_result},
migraphx::verify::tolerance{0.003}));
}
}
......@@ -1122,7 +1125,7 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> no_quant_result;
run_prog(p, ref_t, no_quant_result);
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result));
EXPECT(migraphx::verify::verify_rms_range(quant_result, no_quant_result));
}
}
......@@ -1274,7 +1277,7 @@ TEST_CASE(test_op_capture)
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(vec, cap_vec));
EXPECT(migraphx::verify::verify_rms_range(vec, cap_vec));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -42,7 +42,7 @@ TEST_CASE(abs_test)
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(abs_dyn_test)
......@@ -62,5 +62,5 @@ TEST_CASE(abs_dyn_test)
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
......@@ -45,7 +45,7 @@ TEST_CASE(acos_test)
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(acos_dyn_test)
......@@ -68,5 +68,5 @@ TEST_CASE(acos_dyn_test)
std::vector<float> gold = input_data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
......@@ -45,7 +45,7 @@ TEST_CASE(acosh_test)
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(acosh_dyn_test)
......@@ -68,5 +68,5 @@ TEST_CASE(acosh_dyn_test)
std::vector<float> gold = input_data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
......@@ -51,7 +51,7 @@ TEST_CASE(add_broadcast_test)
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(add_multibroadcast_test)
......@@ -75,7 +75,7 @@ TEST_CASE(add_multibroadcast_test)
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(add_test)
......@@ -91,7 +91,7 @@ TEST_CASE(add_test)
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(add_dyn_test)
......@@ -115,7 +115,7 @@ TEST_CASE(add_dyn_test)
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(fp16_test)
......@@ -134,7 +134,7 @@ TEST_CASE(fp16_test)
std::vector<migraphx::half> results_vector(1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<migraphx::half> gold{c};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(fp32_fp16_test)
......@@ -159,7 +159,7 @@ TEST_CASE(fp32_fp16_test)
auto result = p.eval({}).back();
std::vector<float> res;
result.visit([&](auto output) { res.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res, gold_res));
EXPECT(migraphx::verify::verify_rms_range(res, gold_res));
};
test_case({"all"});
......
......@@ -47,7 +47,7 @@ TEST_CASE(argmax_test_0)
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_1)
......@@ -66,7 +66,7 @@ TEST_CASE(argmax_test_1)
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_2)
......@@ -85,7 +85,7 @@ TEST_CASE(argmax_test_2)
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_neg_2)
......@@ -104,7 +104,7 @@ TEST_CASE(argmax_test_neg_2)
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
TEST_CASE(argmax_dyn_test)
......@@ -126,7 +126,7 @@ TEST_CASE(argmax_dyn_test)
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_nonstd_shape)
......@@ -145,5 +145,5 @@ TEST_CASE(argmax_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold_vec));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment