Commit cb10ae76 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into nhwc_workaround

parents 498e6c9d 75e6618c
This diff is collapsed.
......@@ -1020,7 +1020,7 @@ TEST_CASE(target_copy)
std::vector<float> orig_result;
run_prog(p, ref_t, m, orig_result);
EXPECT(migraphx::verify_range(ref_result, orig_result));
EXPECT(migraphx::verify::verify_range(ref_result, orig_result));
}
}
......@@ -1084,7 +1084,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000));
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000));
}
}
......@@ -1129,7 +1129,7 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> no_quant_result;
run_prog(p, ref_t, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result));
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result));
}
}
......@@ -1281,7 +1281,7 @@ TEST_CASE(test_op_capture)
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec));
EXPECT(migraphx::verify::verify_range(vec, cap_vec));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
EXPECT(migraphx::verify::verify_range(results_vector, sol));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -80,7 +80,7 @@ void dot_2d_test()
auto result = p.eval({}).back();
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
EXPECT(migraphx::verify::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>)
......@@ -131,7 +131,7 @@ void dot_4d_test()
auto result = p.eval({}).back();
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
EXPECT(migraphx::verify::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(dot_4d_test<double>)
......@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_3D_C_test0)
......@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_3D_C_test1)
......@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
-0.95536130,
2.27996211};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_4D_test1)
......@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_4D_alpha_beta_test)
......@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_4D_alpha_beta_C_test)
......@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
EXPECT(migraphx::verify::verify_range(m, m_res));
}
TEST_CASE(dot_2D_C_test0)
......@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
EXPECT(migraphx::verify::verify_range(c, results_vector));
}
TEST_CASE(dot_dyn_4D_test)
......@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
EXPECT(migraphx::verify::verify_range(c, results_vector));
}
TEST_CASE(quant_dot_2args_multi4)
......@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
{
......@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back();
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
EXPECT(migraphx::verify::verify_range(m, gold));
}
}
......
......@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
}
TEST_CASE(argmin_test_nonstd_shape)
......@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
}
TEST_CASE(isnan_broadcast_test)
......@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 0, 0, 1, 1};
EXPECT(migraphx::verify_range(results_vector, correct));
EXPECT(migraphx::verify::verify_range(results_vector, correct));
}
TEST_CASE(squeeze_transpose_test)
......
This diff is collapsed.
This diff is collapsed.
......@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
p2.compile(migraphx::make_target("ref"));
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
visit_all(result1,
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
visit_all(result1, result2)(
[&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_range(r1, r2)); });
};
test_rewrite_pooling(migraphx::op::pooling_mode::max,
......
......@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -947,13 +947,13 @@ TEST_CASE(test_with_type)
TEST_CASE(test_multi_index)
{
migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}};
EXPECT(migraphx::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
EXPECT(migraphx::verify::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
}
TEST_CASE(find_permutation_2d_standard)
......
......@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv2(16);
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2));
EXPECT(migraphx::verify::verify_range(rv1, rv2));
}
TEST_CASE(dot_correctness)
......@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto result2 = p2.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv2(sh3.elements());
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2));
EXPECT(migraphx::verify::verify_range(rv1, rv2));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
if(inputs.empty())
MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens());
normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment