Unverified Commit 68a9a23f authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

add verify namespace (#1952)

parent c4765a6d
...@@ -131,7 +131,7 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro ...@@ -131,7 +131,7 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify::verify_range(results_vector, sol));
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU. An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU.
By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target. By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target.
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace verify {
// Compute the value of a range // Compute the value of a range
template <class R> template <class R>
...@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out ...@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
return error <= threshold; return error <= threshold;
} }
} // namespace verify
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -35,7 +35,7 @@ bool verify_args(const std::string& name, ...@@ -35,7 +35,7 @@ bool verify_args(const std::string& name,
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error; double error;
passed = verify_range(ref, target, tolerance, &error); passed = verify::verify_range(ref, target, tolerance, &error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
...@@ -45,27 +45,27 @@ bool verify_args(const std::string& name, ...@@ -45,27 +45,27 @@ bool verify_args(const std::string& name,
std::cout << "ref:" << ref << std::endl; std::cout << "ref:" << ref << std::endl;
if(target.size() < 32) if(target.size() < 32)
std::cout << "target:" << target << std::endl; std::cout << "target:" << target << std::endl;
if(range_zero(ref)) if(verify::range_zero(ref))
std::cout << "Ref data is all zeros" << std::endl; std::cout << "Ref data is all zeros" << std::endl;
if(range_zero(target)) if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl; std::cout << "Target data is all zeros" << std::endl;
auto mxdiff = max_diff(ref, target); auto mxdiff = verify::max_diff(ref, target);
std::cout << "Max diff: " << mxdiff << std::endl; std::cout << "Max diff: " << mxdiff << std::endl;
auto idx = mismatch_idx(ref, target, float_equal); auto idx = verify::mismatch_idx(ref, target, float_equal);
if(idx < range_distance(ref)) if(idx < verify::range_distance(ref))
{ {
std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx] std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
<< std::endl; << std::endl;
} }
auto ref_nan_idx = find_idx(ref, not_finite); auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0) if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
<< ref[ref_nan_idx] << std::endl; << ref[ref_nan_idx] << std::endl;
auto target_nan_idx = find_idx(target, not_finite); auto target_nan_idx = find_idx(target, verify::not_finite);
if(target_nan_idx >= 0) if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": " std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl; << target[target_nan_idx] << std::endl;
...@@ -73,27 +73,27 @@ bool verify_args(const std::string& name, ...@@ -73,27 +73,27 @@ bool verify_args(const std::string& name,
} }
else else
{ {
if(range_zero(ref)) if(verify::range_zero(ref))
std::cout << "Ref data is all zeros" << std::endl; std::cout << "Ref data is all zeros" << std::endl;
if(range_zero(target)) if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl; std::cout << "Target data is all zeros" << std::endl;
// auto mxdiff = max_diff(ref, target); // auto mxdiff = max_diff(ref, target);
// std::cout << "Max diff: " << mxdiff << std::endl; // std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(ref, target, float_equal); // auto idx = mismatch_idx(ref, target, float_equal);
// if(idx < range_distance(ref)) // if(idx < verify::range_distance(ref))
// { // {
// std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx] // std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
// << std::endl; // << std::endl;
// } // }
auto ref_nan_idx = find_idx(ref, not_finite); auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0) if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
<< ref[ref_nan_idx] << std::endl; << ref[ref_nan_idx] << std::endl;
auto target_nan_idx = find_idx(target, not_finite); auto target_nan_idx = find_idx(target, verify::not_finite);
if(target_nan_idx >= 0) if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": " std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl; << target[target_nan_idx] << std::endl;
......
...@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test) ...@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
migraphx::target gpu_t = migraphx::make_target("gpu"); migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result); run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -64,7 +64,7 @@ TEST_CASE(host_same_buffer_copy) ...@@ -64,7 +64,7 @@ TEST_CASE(host_same_buffer_copy)
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
std::vector<float> results_vector(ss.elements(), -1); std::vector<float> results_vector(ss.elements(), -1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c_vec, results_vector)); EXPECT(migraphx::verify::verify_range(c_vec, results_vector));
} }
TEST_CASE(arguments_lifetime) TEST_CASE(arguments_lifetime)
......
...@@ -52,7 +52,7 @@ TEST_CASE(gpu_target_copy) ...@@ -52,7 +52,7 @@ TEST_CASE(gpu_target_copy)
std::vector<int8_t> val_final; std::vector<int8_t> val_final;
ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); }); ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
EXPECT(migraphx::verify_range(val_orig, val_final)); EXPECT(migraphx::verify::verify_range(val_orig, val_final));
} }
TEST_CASE(int8_quantization) TEST_CASE(int8_quantization)
...@@ -118,9 +118,9 @@ TEST_CASE(int8_quantization) ...@@ -118,9 +118,9 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much // the regular pipeline uses the rewrite_quantization in the much
// earlier stage. // earlier stage.
if(migraphx::gpu::mlir_enabled()) if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify_range(ref_result, gpu_result, 1e5)); EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 1e5));
else else
EXPECT(migraphx::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
} }
} }
......
This diff is collapsed.
...@@ -1020,7 +1020,7 @@ TEST_CASE(target_copy) ...@@ -1020,7 +1020,7 @@ TEST_CASE(target_copy)
std::vector<float> orig_result; std::vector<float> orig_result;
run_prog(p, ref_t, m, orig_result); run_prog(p, ref_t, m, orig_result);
EXPECT(migraphx::verify_range(ref_result, orig_result)); EXPECT(migraphx::verify::verify_range(ref_result, orig_result));
} }
} }
...@@ -1084,7 +1084,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1084,7 +1084,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000)); EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000));
} }
} }
...@@ -1129,7 +1129,7 @@ TEST_CASE(int8_quantization_conv) ...@@ -1129,7 +1129,7 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, no_quant_result); run_prog(p, ref_t, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result)); EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result));
} }
} }
...@@ -1281,7 +1281,7 @@ TEST_CASE(test_op_capture) ...@@ -1281,7 +1281,7 @@ TEST_CASE(test_op_capture)
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); }); cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec)); EXPECT(migraphx::verify::verify_range(vec, cap_vec));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors) ...@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify::verify_range(results_vector, sol));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -80,7 +80,7 @@ void dot_2d_test() ...@@ -80,7 +80,7 @@ void dot_2d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(dot_2d_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>) TEST_CASE_REGISTER(dot_2d_test<double>)
...@@ -131,7 +131,7 @@ void dot_4d_test() ...@@ -131,7 +131,7 @@ void dot_4d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(dot_4d_test<float>) TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(dot_4d_test<double>) TEST_CASE_REGISTER(dot_4d_test<double>)
...@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test) ...@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
0.40245487, 0.40245487,
1.80182751}; 1.80182751};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_3D_C_test0) TEST_CASE(dot_3D_C_test0)
...@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0) ...@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
0.40245487, 0.40245487,
1.80182751}; 1.80182751};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_3D_C_test1) TEST_CASE(dot_3D_C_test1)
...@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1) ...@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
-0.95536130, -0.95536130,
2.27996211}; 2.27996211};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_test1) TEST_CASE(dot_4D_test1)
...@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1) ...@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164, -0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906}; 3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_alpha_beta_test) TEST_CASE(dot_4D_alpha_beta_test)
...@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test) ...@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824, -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845}; 0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_alpha_beta_C_test) TEST_CASE(dot_4D_alpha_beta_C_test)
...@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test) ...@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824, -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845}; 0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_2D_C_test0) TEST_CASE(dot_2D_C_test0)
...@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0) ...@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product) ...@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product) ...@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -642,7 +642,7 @@ TEST_CASE(dot_vm) ...@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -679,7 +679,7 @@ TEST_CASE(dot_vm) ...@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -726,7 +726,7 @@ TEST_CASE(dot_vm) ...@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -774,7 +774,7 @@ TEST_CASE(dot_vm) ...@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -813,7 +813,7 @@ TEST_CASE(dot_mv) ...@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -851,7 +851,7 @@ TEST_CASE(dot_mv) ...@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -895,7 +895,7 @@ TEST_CASE(dot_mv) ...@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1) ...@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1) ...@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2) ...@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2) ...@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2) ...@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2) ...@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test) ...@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE(dot_dyn_4D_test) TEST_CASE(dot_dyn_4D_test)
...@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test) ...@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE(quant_dot_2args_multi4) TEST_CASE(quant_dot_2args_multi4)
...@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
......
...@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape) ...@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
} }
TEST_CASE(argmin_test_nonstd_shape) TEST_CASE(argmin_test_nonstd_shape)
...@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape) ...@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
} }
TEST_CASE(isnan_broadcast_test) TEST_CASE(isnan_broadcast_test)
...@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test) ...@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 0, 0, 1, 1}; std::vector<float> correct = {0, 0, 0, 0, 1, 1};
EXPECT(migraphx::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_range(results_vector, correct));
} }
TEST_CASE(squeeze_transpose_test) TEST_CASE(squeeze_transpose_test)
......
This diff is collapsed.
This diff is collapsed.
...@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
p2.compile(migraphx::make_target("ref")); p2.compile(migraphx::make_target("ref"));
auto result1 = p1.eval({}).back(); auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back(); auto result2 = p2.eval({}).back();
visit_all(result1, visit_all(result1, result2)(
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); [&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_range(r1, r2)); });
}; };
test_rewrite_pooling(migraphx::op::pooling_mode::max, test_rewrite_pooling(migraphx::op::pooling_mode::max,
......
...@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target) ...@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125}; std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_range(results_vector, gold));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -947,13 +947,13 @@ TEST_CASE(test_with_type) ...@@ -947,13 +947,13 @@ TEST_CASE(test_with_type)
TEST_CASE(test_multi_index) TEST_CASE(test_multi_index)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}}; migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}};
EXPECT(migraphx::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0})); EXPECT(migraphx::verify::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4})); EXPECT(migraphx::verify::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0})); EXPECT(migraphx::verify::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2})); EXPECT(migraphx::verify::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0})); EXPECT(migraphx::verify::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0})); EXPECT(migraphx::verify::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4})); EXPECT(migraphx::verify::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
} }
TEST_CASE(find_permutation_2d_standard) TEST_CASE(find_permutation_2d_standard)
......
...@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness) ...@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back(); auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv2(16); std::vector<float> rv2(16);
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2)); EXPECT(migraphx::verify::verify_range(rv1, rv2));
} }
TEST_CASE(dot_correctness) TEST_CASE(dot_correctness)
...@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness) ...@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto result2 = p2.eval({{"a", a}, {"b", b}}).back(); auto result2 = p2.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv2(sh3.elements()); std::vector<float> rv2(sh3.elements());
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2)); EXPECT(migraphx::verify::verify_range(rv1, rv2));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment