Unverified Commit 68a9a23f authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

add verify namespace (#1952)

parent c4765a6d
......@@ -131,7 +131,7 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
EXPECT(migraphx::verify::verify_range(results_vector, sol));
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU.
By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target.
......
......@@ -35,6 +35,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace verify {
// Compute the value of a range
template <class R>
......@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
return error <= threshold;
}
} // namespace verify
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -35,7 +35,7 @@ bool verify_args(const std::string& name,
bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error;
passed = verify_range(ref, target, tolerance, &error);
passed = verify::verify_range(ref, target, tolerance, &error);
if(not passed)
{
// TODO: Check for nans
......@@ -45,27 +45,27 @@ bool verify_args(const std::string& name,
std::cout << "ref:" << ref << std::endl;
if(target.size() < 32)
std::cout << "target:" << target << std::endl;
if(range_zero(ref))
if(verify::range_zero(ref))
std::cout << "Ref data is all zeros" << std::endl;
if(range_zero(target))
if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl;
auto mxdiff = max_diff(ref, target);
auto mxdiff = verify::max_diff(ref, target);
std::cout << "Max diff: " << mxdiff << std::endl;
auto idx = mismatch_idx(ref, target, float_equal);
if(idx < range_distance(ref))
auto idx = verify::mismatch_idx(ref, target, float_equal);
if(idx < verify::range_distance(ref))
{
std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
<< std::endl;
}
auto ref_nan_idx = find_idx(ref, not_finite);
auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
<< ref[ref_nan_idx] << std::endl;
auto target_nan_idx = find_idx(target, not_finite);
auto target_nan_idx = find_idx(target, verify::not_finite);
if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl;
......@@ -73,27 +73,27 @@ bool verify_args(const std::string& name,
}
else
{
if(range_zero(ref))
if(verify::range_zero(ref))
std::cout << "Ref data is all zeros" << std::endl;
if(range_zero(target))
if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl;
// auto mxdiff = max_diff(ref, target);
// std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(ref, target, float_equal);
// if(idx < range_distance(ref))
// if(idx < verify::range_distance(ref))
// {
// std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
// << std::endl;
// }
auto ref_nan_idx = find_idx(ref, not_finite);
auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
<< ref[ref_nan_idx] << std::endl;
auto target_nan_idx = find_idx(target, not_finite);
auto target_nan_idx = find_idx(target, verify::not_finite);
if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl;
......
......@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify_range(ref_result, gpu_result));
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -64,7 +64,7 @@ TEST_CASE(host_same_buffer_copy)
auto result = p.eval(pp).back();
std::vector<float> results_vector(ss.elements(), -1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c_vec, results_vector));
EXPECT(migraphx::verify::verify_range(c_vec, results_vector));
}
TEST_CASE(arguments_lifetime)
......
......@@ -52,7 +52,7 @@ TEST_CASE(gpu_target_copy)
std::vector<int8_t> val_final;
ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
EXPECT(migraphx::verify_range(val_orig, val_final));
EXPECT(migraphx::verify::verify_range(val_orig, val_final));
}
TEST_CASE(int8_quantization)
......@@ -118,9 +118,9 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much
// earlier stage.
if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify_range(ref_result, gpu_result, 1e5));
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 1e5));
else
EXPECT(migraphx::verify_range(ref_result, gpu_result));
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
}
}
......
This diff is collapsed.
......@@ -1020,7 +1020,7 @@ TEST_CASE(target_copy)
std::vector<float> orig_result;
run_prog(p, ref_t, m, orig_result);
EXPECT(migraphx::verify_range(ref_result, orig_result));
EXPECT(migraphx::verify::verify_range(ref_result, orig_result));
}
}
......@@ -1084,7 +1084,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000));
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000));
}
}
......@@ -1129,7 +1129,7 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> no_quant_result;
run_prog(p, ref_t, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result));
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result));
}
}
......@@ -1281,7 +1281,7 @@ TEST_CASE(test_op_capture)
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec));
EXPECT(migraphx::verify::verify_range(vec, cap_vec));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
EXPECT(migraphx::verify::verify_range(results_vector, sol));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -80,7 +80,7 @@ void dot_2d_test()
auto result = p.eval({}).back();
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
EXPECT(migraphx::verify::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>)
......@@ -131,7 +131,7 @@ void dot_4d_test()
auto result = p.eval({}).back();
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
EXPECT(migraphx::verify::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(dot_4d_test<double>)
......@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_3D_C_test0)
......@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_3D_C_test1)
......@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
-0.95536130,
2.27996211};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_4D_test1)
......@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_4D_alpha_beta_test)
......@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_4D_alpha_beta_C_test)
......@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_2D_C_test0)
......@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
EXPECT(migraphx::verify::verify_range(c, results_vector));
}
TEST_CASE(dot_dyn_4D_test)
......@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
EXPECT(migraphx::verify::verify_range(c, results_vector));
}
TEST_CASE(quant_dot_2args_multi4)
......@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......
......@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
}
TEST_CASE(argmin_test_nonstd_shape)
......@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
}
TEST_CASE(isnan_broadcast_test)
......@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 0, 0, 1, 1};
EXPECT(migraphx::verify_range(results_vector, correct));
EXPECT(migraphx::verify::verify_range(results_vector, correct));
}
TEST_CASE(squeeze_transpose_test)
......
This diff is collapsed.
This diff is collapsed.
......@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
p2.compile(migraphx::make_target("ref"));
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
visit_all(result1,
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
visit_all(result1, result2)(
[&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_range(r1, r2)); });
};
test_rewrite_pooling(migraphx::op::pooling_mode::max,
......
......@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -947,13 +947,13 @@ TEST_CASE(test_with_type)
TEST_CASE(test_multi_index)
{
migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}};
EXPECT(migraphx::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
EXPECT(migraphx::verify::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
}
TEST_CASE(find_permutation_2d_standard)
......
......@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv2(16);
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2));
EXPECT(migraphx::verify::verify_range(rv1, rv2));
}
TEST_CASE(dot_correctness)
......@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto result2 = p2.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv2(sh3.elements());
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2));
EXPECT(migraphx::verify::verify_range(rv1, rv2));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment