Commit cb10ae76 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into nhwc_workaround

parents 498e6c9d 75e6618c
This diff is collapsed.
...@@ -1020,7 +1020,7 @@ TEST_CASE(target_copy) ...@@ -1020,7 +1020,7 @@ TEST_CASE(target_copy)
std::vector<float> orig_result; std::vector<float> orig_result;
run_prog(p, ref_t, m, orig_result); run_prog(p, ref_t, m, orig_result);
EXPECT(migraphx::verify_range(ref_result, orig_result)); EXPECT(migraphx::verify::verify_range(ref_result, orig_result));
} }
} }
...@@ -1084,7 +1084,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1084,7 +1084,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000)); EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000));
} }
} }
...@@ -1129,7 +1129,7 @@ TEST_CASE(int8_quantization_conv) ...@@ -1129,7 +1129,7 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, no_quant_result); run_prog(p, ref_t, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result)); EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result));
} }
} }
...@@ -1281,7 +1281,7 @@ TEST_CASE(test_op_capture) ...@@ -1281,7 +1281,7 @@ TEST_CASE(test_op_capture)
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); }); cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec)); EXPECT(migraphx::verify::verify_range(vec, cap_vec));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors) ...@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify::verify_range(results_vector, sol));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -80,7 +80,7 @@ void dot_2d_test() ...@@ -80,7 +80,7 @@ void dot_2d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(dot_2d_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>) TEST_CASE_REGISTER(dot_2d_test<double>)
...@@ -131,7 +131,7 @@ void dot_4d_test() ...@@ -131,7 +131,7 @@ void dot_4d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(dot_4d_test<float>) TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(dot_4d_test<double>) TEST_CASE_REGISTER(dot_4d_test<double>)
...@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test) ...@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
0.40245487, 0.40245487,
1.80182751}; 1.80182751};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_3D_C_test0) TEST_CASE(dot_3D_C_test0)
...@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0) ...@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
0.40245487, 0.40245487,
1.80182751}; 1.80182751};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_3D_C_test1) TEST_CASE(dot_3D_C_test1)
...@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1) ...@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
-0.95536130, -0.95536130,
2.27996211}; 2.27996211};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_test1) TEST_CASE(dot_4D_test1)
...@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1) ...@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164, -0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906}; 3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_alpha_beta_test) TEST_CASE(dot_4D_alpha_beta_test)
...@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test) ...@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824, -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845}; 0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_alpha_beta_C_test) TEST_CASE(dot_4D_alpha_beta_C_test)
...@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test) ...@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824, -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845}; 0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_2D_C_test0) TEST_CASE(dot_2D_C_test0)
...@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0) ...@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product) ...@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product) ...@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -642,7 +642,7 @@ TEST_CASE(dot_vm) ...@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -679,7 +679,7 @@ TEST_CASE(dot_vm) ...@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -726,7 +726,7 @@ TEST_CASE(dot_vm) ...@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -774,7 +774,7 @@ TEST_CASE(dot_vm) ...@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -813,7 +813,7 @@ TEST_CASE(dot_mv) ...@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -851,7 +851,7 @@ TEST_CASE(dot_mv) ...@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -895,7 +895,7 @@ TEST_CASE(dot_mv) ...@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1) ...@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1) ...@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2) ...@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2) ...@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2) ...@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2) ...@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test) ...@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE(dot_dyn_4D_test) TEST_CASE(dot_dyn_4D_test)
...@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test) ...@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE(quant_dot_2args_multi4) TEST_CASE(quant_dot_2args_multi4)
...@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
......
...@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape) ...@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
} }
TEST_CASE(argmin_test_nonstd_shape) TEST_CASE(argmin_test_nonstd_shape)
...@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape) ...@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
} }
TEST_CASE(isnan_broadcast_test) TEST_CASE(isnan_broadcast_test)
...@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test) ...@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 0, 0, 1, 1}; std::vector<float> correct = {0, 0, 0, 0, 1, 1};
EXPECT(migraphx::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_range(results_vector, correct));
} }
TEST_CASE(squeeze_transpose_test) TEST_CASE(squeeze_transpose_test)
......
This diff is collapsed.
This diff is collapsed.
...@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
p2.compile(migraphx::make_target("ref")); p2.compile(migraphx::make_target("ref"));
auto result1 = p1.eval({}).back(); auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back(); auto result2 = p2.eval({}).back();
visit_all(result1, visit_all(result1, result2)(
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); [&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_range(r1, r2)); });
}; };
test_rewrite_pooling(migraphx::op::pooling_mode::max, test_rewrite_pooling(migraphx::op::pooling_mode::max,
......
...@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target) ...@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125}; std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_range(results_vector, gold));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -947,13 +947,13 @@ TEST_CASE(test_with_type) ...@@ -947,13 +947,13 @@ TEST_CASE(test_with_type)
TEST_CASE(test_multi_index) TEST_CASE(test_multi_index)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}}; migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}};
EXPECT(migraphx::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0})); EXPECT(migraphx::verify::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4})); EXPECT(migraphx::verify::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0})); EXPECT(migraphx::verify::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2})); EXPECT(migraphx::verify::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0})); EXPECT(migraphx::verify::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0})); EXPECT(migraphx::verify::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4})); EXPECT(migraphx::verify::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
} }
TEST_CASE(find_permutation_2d_standard) TEST_CASE(find_permutation_2d_standard)
......
...@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness) ...@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back(); auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv2(16); std::vector<float> rv2(16);
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2)); EXPECT(migraphx::verify::verify_range(rv1, rv2));
} }
TEST_CASE(dot_correctness) TEST_CASE(dot_correctness)
...@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness) ...@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto result2 = p2.eval({{"a", a}, {"b", b}}).back(); auto result2 = p2.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv2(sh3.elements()); std::vector<float> rv2(sh3.elements());
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2)); EXPECT(migraphx::verify::verify_range(rv1, rv2));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) ...@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
if(inputs.empty()) if(inputs.empty())
MIGRAPHX_THROW("At least one input is required for " + x.name()); MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens()); normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment