Unverified Commit 25e8cf0b authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into test_onnx_zoo

parents a313a68e 635502be
...@@ -181,6 +181,24 @@ TEST_CASE(argmax_test) ...@@ -181,6 +181,24 @@ TEST_CASE(argmax_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(argmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"x",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}});
auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("argmax_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(argmin_test) TEST_CASE(argmin_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -273,6 +291,51 @@ TEST_CASE(averagepool_3d_test) ...@@ -273,6 +291,51 @@ TEST_CASE(averagepool_3d_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(averagepool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}, {5, 5, 0}}});
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 3, 3}}}),
l0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("averagepool_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(averagepool_dyn_autopad_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("averagepool_dyn_autopad_error_test.onnx", options); }));
}
TEST_CASE(averagepool_dyn_asym_padding_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("averagepool_dyn_asym_padding_error_test.onnx", options); }));
}
TEST_CASE(averagepool_dyn_cip_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("averagepool_dyn_cip_error_test.onnx", options); }));
}
TEST_CASE(averagepool_notset_test) TEST_CASE(averagepool_notset_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -1055,6 +1118,25 @@ TEST_CASE(conv_dynamic_batch_test) ...@@ -1055,6 +1118,25 @@ TEST_CASE(conv_dynamic_batch_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(conv_dynamic_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 = mm->add_parameter(
"0", {migraphx::shape::float_type, {{1, 6, 0}, {3, 3, 0}, {32, 32, 0}, {32, 32, 0}}});
auto x1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto x2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}});
auto x3 = mm->add_instruction(migraphx::make_op("convolution"), x0, x1);
auto x4 = mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), x2, x3);
auto x5 = mm->add_instruction(migraphx::make_op("add"), x3, x4);
mm->add_return({x5});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 6, 0};
auto prog = migraphx::parse_onnx("conv_dynamic_bias_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(conv_dynamic_img_test) TEST_CASE(conv_dynamic_img_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -1755,6 +1837,16 @@ migraphx::program create_external_data_prog() ...@@ -1755,6 +1837,16 @@ migraphx::program create_external_data_prog()
return p; return p;
} }
TEST_CASE(external_constant_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{{migraphx::shape::int64_type, {3}}, {0, 1, 2}});
auto prog = optimize_onnx("external_constant_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(external_data_test) TEST_CASE(external_data_test)
{ {
migraphx::program p = create_external_data_prog(); migraphx::program p = create_external_data_prog();
...@@ -1914,6 +2006,23 @@ TEST_CASE(flatten_nonstd_test) ...@@ -1914,6 +2006,23 @@ TEST_CASE(flatten_nonstd_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(flatten_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto ret = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), c0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("flatten_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(floor_test) TEST_CASE(floor_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2026,64 +2135,64 @@ TEST_CASE(gemm_test) ...@@ -2026,64 +2135,64 @@ TEST_CASE(gemm_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {8, 6}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type}); auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto alpha = 2.f; auto alpha = 0.5f;
auto beta = 2.0f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b); auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb); mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
auto prog = optimize_onnx("gemm_test.onnx"); auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_ex_test) TEST_CASE(gemm_no_C_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}}); auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}}); auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type});
auto alpha = 0.5f; auto alpha = 2.f;
auto beta = 0.8f; auto beta = 2.0f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b); mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_test.onnx"); auto prog = optimize_onnx("gemm_no_C_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_ex_brcst_test) TEST_CASE(gemm_brcst_C_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}}); auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {5, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}}); auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {6, 1}});
std::vector<std::size_t> out_lens{1, 1, 6, 7}; std::vector<std::size_t> out_lens{6, 7};
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b = auto l2_b =
...@@ -2093,7 +2202,7 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -2093,7 +2202,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b); auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb); mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx"); auto prog = optimize_onnx("gemm_brcst_C_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -2101,17 +2210,17 @@ TEST_CASE(gemm_half_test) ...@@ -2101,17 +2210,17 @@ TEST_CASE(gemm_half_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 6}}); auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::half_type, {8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 7}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::half_type, {8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {1, 1, 6, 1}}); auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::half_type, {6, 1}});
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction( t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7}; std::vector<std::size_t> lens = {6, 7};
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction( l2 = mm->add_instruction(
...@@ -2127,6 +2236,73 @@ TEST_CASE(gemm_half_test) ...@@ -2127,6 +2236,73 @@ TEST_CASE(gemm_half_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_dyn_inner_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {6, 6, 0}}});
auto l1 = mm->add_parameter(
"B", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {7, 7, 0}}});
auto alpha = 0.5f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 8};
auto prog = migraphx::parse_onnx("gemm_dyn_inner_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_outer_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{5, 5, 0}, {5, 10, 7}}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto alpha = 2.f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 7};
auto prog = migraphx::parse_onnx("gemm_dyn_outer_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 =
mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {{8, 8}, {1, 10}}});
auto x1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto x2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {1, 7}});
auto x0_t = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x0);
auto dot = mm->add_instruction(migraphx::make_op("dot"), x0_t, x1);
auto x2_b = mm->add_instruction(migraphx::make_op("multibroadcast"), x2, dot);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, x2_b);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10};
auto prog = parse_onnx("gemm_dyn_bias_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_rank_error)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); }));
}
TEST_CASE(globalavgpool_test) TEST_CASE(globalavgpool_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2144,6 +2320,28 @@ TEST_CASE(globalavgpool_test) ...@@ -2144,6 +2320,28 @@ TEST_CASE(globalavgpool_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(globalavgpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", {16, 16}},
{"padding", {0, 0, 0, 0}}}),
input);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("globalavgpool_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(globallppool_test) TEST_CASE(globallppool_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2161,6 +2359,29 @@ TEST_CASE(globallppool_test) ...@@ -2161,6 +2359,29 @@ TEST_CASE(globallppool_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(globallppool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {16, 32, 0}, {16, 32, 0}}});
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::lpnorm},
{"dyn_global", true},
{"padding", {0, 0, 0, 0}},
{"lengths", {}}}),
input);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {16, 32, 0};
auto prog = migraphx::parse_onnx("globallppool_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(globalmaxpool_test) TEST_CASE(globalmaxpool_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2178,6 +2399,28 @@ TEST_CASE(globalmaxpool_test) ...@@ -2178,6 +2399,28 @@ TEST_CASE(globalmaxpool_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(globalmaxpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {32, 32, 0}, {32, 32, 0}}});
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"lengths", {32, 32}},
{"padding", {0, 0, 0, 0}}}),
input);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("globalmaxpool_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(greater_test) TEST_CASE(greater_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -3256,6 +3499,92 @@ TEST_CASE(matmul_vv_test) ...@@ -3256,6 +3499,92 @@ TEST_CASE(matmul_vv_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(matmul_dyn_mm_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"1", migraphx::shape{migraphx::shape::float_type, {{4, 8, 6}, {7, 7, 0}}});
auto l1 = mm->add_parameter(
"2", migraphx::shape{migraphx::shape::float_type, {{7, 7, 0}, {1, 5, 3}}});
auto ret = migraphx::add_apply_alpha_beta(*mm, {l0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["1"] = {{4, 8, 6}, {7, 7, 0}};
options.map_dyn_input_dims["2"] = {{7, 7, 0}, {1, 5, 3}};
auto prog = parse_onnx("matmul_dyn_mm_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(matmul_dyn_mv_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"1", migraphx::shape{migraphx::shape::float_type, {{4, 8, 6}, {7, 7, 0}}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res = migraphx::add_apply_alpha_beta(*mm, {l0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["1"] = {{4, 8, 6}, {7, 7, 0}};
auto prog = parse_onnx("matmul_dyn_mv_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(matmul_dyn_vm_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = mm->add_parameter(
"2", migraphx::shape{migraphx::shape::float_type, {{7, 7, 0}, {4, 10, 8}}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto res = migraphx::add_apply_alpha_beta(*mm, {sl0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["2"] = {{7, 7, 0}, {4, 10, 8}};
auto prog = parse_onnx("matmul_dyn_vm_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(matmul_dyn_vv_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{5, 8, 7};
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {dd}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {dd}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res =
migraphx::add_apply_alpha_beta(*mm, {sl0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = dd;
auto prog = parse_onnx("matmul_dyn_vv_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(matmul_dyn_broadcast_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("matmul_dyn_broadcast_error.onnx", options); }));
}
TEST_CASE(matmulinteger_test) TEST_CASE(matmulinteger_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -3269,6 +3598,13 @@ TEST_CASE(matmulinteger_test) ...@@ -3269,6 +3598,13 @@ TEST_CASE(matmulinteger_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(matmulinteger_dyn_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("matmulinteger_dyn_error.onnx", options); }));
}
TEST_CASE(max_test) TEST_CASE(max_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -3859,6 +4195,44 @@ TEST_CASE(pad_3arg_test) ...@@ -3859,6 +4195,44 @@ TEST_CASE(pad_3arg_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(pad_attr_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{2, 4, 2}, {2, 4, 2}}});
auto ret = mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), x);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 4, 2}, {2, 4, 2}};
auto prog = parse_onnx("pad_attr_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(pad_cnst_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{2, 4, 2}, {2, 4, 2}}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 0, 1}});
auto ret = mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 2, 0, 1}}}), x);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 4, 2}, {2, 4, 2}};
auto prog = parse_onnx("pad_cnst_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(pad_dyn_reflect_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 4, 2};
EXPECT(test::throws([&] { migraphx::parse_onnx("pad_dyn_reflect_error.onnx", options); }));
}
TEST_CASE(pad_reflect_test) TEST_CASE(pad_reflect_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -4311,6 +4685,50 @@ TEST_CASE(reducel1_test) ...@@ -4311,6 +4685,50 @@ TEST_CASE(reducel1_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(reducel1_dyn_test)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
// a shape with 4 dynamic dimensions
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
{{3, 3, 0}, {3, 5, 0}, {4, 6, 5}, {5, 7, 6}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), abs_ins);
auto sq_ins = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {-2}}}), sum_ins);
mm->add_return({sq_ins});
migraphx::onnx_options options;
options.map_dyn_input_dims["x"] = {{3, 3}, {3, 5}, {4, 6, 5}, {5, 7, 6}};
auto prog = migraphx::parse_onnx("reducel1_dyn_test.onnx", options);
EXPECT(p == prog);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
// No axes given in the onnx file. Parser should default to all axes.
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
{{3, 3, 0}, {3, 5, 0}, {4, 6, 5}, {5, 7, 6}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), abs_ins);
auto sq_ins =
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1, 2, 3}}}), sum_ins);
mm->add_return({sq_ins});
migraphx::onnx_options options;
options.map_dyn_input_dims["x"] = {{3, 3}, {3, 5}, {4, 6, 5}, {5, 7, 6}};
auto prog = migraphx::parse_onnx("reducel1_dyn_noaxes_test.onnx", options);
EXPECT(p == prog);
}
}
TEST_CASE(reducel2_test) TEST_CASE(reducel2_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -4361,6 +4779,24 @@ TEST_CASE(reducemax_test) ...@@ -4361,6 +4779,24 @@ TEST_CASE(reducemax_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(reducemax_dyn_test)
{
// input shape with 4 dynamic dimensions
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"x", migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}, {6, 6}}});
auto r0 = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), l0);
auto r1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), r0);
mm->add_return({r1});
migraphx::onnx_options options;
options.map_dyn_input_dims["x"] = {{3, 3}, {4, 4}, {5, 5}, {6, 6}};
auto prog = migraphx::parse_onnx("reducemax_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(reducemean_test) TEST_CASE(reducemean_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -5492,6 +5928,23 @@ TEST_CASE(softmax_nonstd_input_test) ...@@ -5492,6 +5928,23 @@ TEST_CASE(softmax_nonstd_input_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(softmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}});
auto ret = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), l0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("softmax_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(softplus_test) TEST_CASE(softplus_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -5690,6 +6143,29 @@ TEST_CASE(squeeze_unsqueeze_test) ...@@ -5690,6 +6143,29 @@ TEST_CASE(squeeze_unsqueeze_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(squeeze_unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 4, 0}, {1, 1, 0}, {1, 1, 0}, {1, 4, 0}, {1, 1, 0}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto ret = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), c1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("squeeze_unsqueeze_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(squeeze_axes_input_test) TEST_CASE(squeeze_axes_input_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -5973,6 +6449,24 @@ TEST_CASE(transpose_test) ...@@ -5973,6 +6449,24 @@ TEST_CASE(transpose_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(transpose_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}});
std::vector<int64_t> perm{0, 3, 1, 2};
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
mm->add_return({t0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("transpose_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(topk_attrk_test) TEST_CASE(topk_attrk_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -6058,6 +6552,11 @@ TEST_CASE(transpose_gather_test) ...@@ -6058,6 +6552,11 @@ TEST_CASE(transpose_gather_test)
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
TEST_CASE(trilu_neg_k_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("trilu_neg_k_test.onnx"); }));
}
TEST_CASE(undefined_test) TEST_CASE(undefined_test)
{ {
migraphx::program p; migraphx::program p;
......
trilu_batch_diff_k_test:i

x
ky"Trilutrilu_batch_diff_k_test*
:BkZ
x



b
y



B
\ No newline at end of file
trilu_neg_k_test:c

x
ky"Trilutrilu_neg_k_test*:
BkZ
x


b
y


B
\ No newline at end of file
trilu_out_k_test:Z

x
ky"Trilutrilu_out_k_test*
:BkZ
x


b
y


B
\ No newline at end of file
trilu_row_one_test:\

x
ky"Trilutrilu_row_one_test*
:BkZ
x


b
y


B
\ No newline at end of file

trilu_test:E
xy"Trilu
trilu_testZ
x


b
y


B
\ No newline at end of file
...@@ -451,6 +451,94 @@ TEST_CASE(gather_elements) ...@@ -451,6 +451,94 @@ TEST_CASE(gather_elements)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(gemm_test)
{
migraphx::program p = migraphx::parse_onnx("gemm_brcst_C_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape a_shape{migraphx::shape::float_type, {5, 6}};
std::vector<float> a_data = {0.26472837, 0.8525864, 0.41929847, 0.14151508, 0.43216065,
0.67468566, 0.42488748, 0.82021785, 0.9782456, 0.5794279,
0.6627283, 0.4790396, 0.9237051, 0.7340607, 0.67379653,
0.87168175, 0.37324256, 0.33278653, 0.42736676, 0.024699844,
0.75851107, 0.48719302, 0.5834426, 0.6938476, 0.43747696,
0.24054702, 0.26912406, 0.6760658, 0.5419149, 0.89949054};
migraphx::shape b_shape{migraphx::shape::float_type, {5, 7}};
std::vector<float> b_data = {
0.65727437, 0.54262096, 0.14126152, 0.8994123, 0.21831702, 0.81191784, 0.9371278,
0.3438551, 0.7121373, 0.90316695, 0.26614252, 0.80144906, 0.80301756, 0.49930334,
0.0719704, 0.63484156, 0.7343097, 0.32130218, 0.7094916, 0.6116475, 0.74144083,
0.021210382, 0.38724765, 0.44830495, 0.62347615, 0.022489505, 0.23316588, 0.76540905,
0.895689, 0.81540287, 0.223875, 0.9275573, 0.4621397, 0.70785195, 0.5658555};
migraphx::shape c_shape{migraphx::shape::float_type, {6, 1}};
std::vector<float> c_data = {
0.07358502, 0.13792239, 0.8574055, 0.40553397, 0.38205826, 0.62062204};
migraphx::parameter_map params;
params["A"] = migraphx::argument(a_shape, a_data.data());
params["B"] = migraphx::argument(b_shape, b_data.data());
params["C"] = migraphx::argument(c_shape, c_data.data());
auto result = p.eval(params).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.45261115, 0.83629227, 0.7533463, 0.7189715, 0.69160205, 0.824082, 0.9187499,
0.6659525, 0.96956736, 0.84293026, 0.8400868, 0.84835225, 1.0982862, 1.0642393,
1.1447254, 1.6184721, 1.6048342, 1.4741788, 1.4334437, 1.638659, 1.7428316,
0.8098607, 1.2157929, 1.1010075, 1.0706307, 1.0429881, 1.1771785, 1.2362702,
0.8239243, 1.1112559, 0.9639262, 1.0813537, 0.8825792, 1.121141, 1.1885703,
1.2227502, 1.4568202, 1.1388762, 1.55058, 1.0958102, 1.4637487, 1.5756242};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(gemm_half_test)
{
migraphx::program p = migraphx::parse_onnx("gemm_half_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape a_shape{migraphx::shape::half_type, {8, 6}};
std::vector tmp = {0.2646, 0.8525, 0.4192, 0.1415, 0.4321, 0.675, 0.4248, 0.8203,
0.978, 0.5796, 0.6626, 0.479, 0.924, 0.734, 0.674, 0.8716,
0.3733, 0.3328, 0.4272, 0.0247, 0.7583, 0.4873, 0.5835, 0.694,
0.4375, 0.2406, 0.269, 0.6763, 0.542, 0.8994, 0.657, 0.5425,
0.1412, 0.8994, 0.2183, 0.812, 0.937, 0.3438, 0.712, 0.9033,
0.266, 0.8013, 0.803, 0.4993, 0.07196, 0.635, 0.7344, 0.3213};
std::vector<migraphx::half> a_data{tmp.cbegin(), tmp.cend()};
migraphx::shape b_shape{migraphx::shape::half_type, {8, 7}};
tmp = {0.7095, 0.612, 0.741, 0.02121, 0.3872, 0.4482, 0.6235, 0.02249, 0.2332, 0.7656,
0.8955, 0.8154, 0.2239, 0.9277, 0.4622, 0.708, 0.566, 0.0736, 0.138, 0.8574,
0.4055, 0.382, 0.6206, 0.424, 0.3674, 0.435, 0.998, 0.3594, 0.701, 0.6216,
0.01826, 0.6313, 0.514, 0.1095, 0.3203, 0.01636, 0.537, 0.01952, 0.4502, 0.8965,
0.5415, 0.7456, 0.793, 0.756, 0.9, 0.5264, 0.05368, 0.4214, 0.276, 0.1517,
0.08453, 0.83, 0.417, 0.1682, 0.845, 0.1729};
std::vector<migraphx::half> b_data{tmp.cbegin(), tmp.cend()};
migraphx::shape c_shape{migraphx::shape::half_type, {6, 1}};
tmp = {0.10846, 0.672, 0.527, 0.94, 0.429, 0.2291};
std::vector<migraphx::half> c_data{tmp.cbegin(), tmp.cend()};
migraphx::parameter_map params;
params["A"] = migraphx::argument(a_shape, a_data.data());
params["B"] = migraphx::argument(b_shape, b_data.data());
params["C"] = migraphx::argument(c_shape, c_data.data());
auto result = p.eval(params).back();
std::vector<migraphx::half> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
tmp = {1.071, 1.378, 1.465, 1.093, 0.968, 1.542, 1.145, 1.287, 1.533, 1.75, 1.338,
1.449, 1.592, 1.668, 1.265, 1.531, 1.656, 1.348, 1.2705, 1.525, 1.479, 1.754,
2.143, 2.062, 1.921, 1.836, 2.203, 1.952, 1.055, 1.225, 1.418, 1.209, 1.155,
1.42, 1.234, 1.302, 1.593, 1.368, 1.289, 1.327, 1.451, 1.394};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(greaterorequal_test) TEST_CASE(greaterorequal_test)
{ {
migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx"); migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx");
...@@ -1370,4 +1458,73 @@ TEST_CASE(where_test) ...@@ -1370,4 +1458,73 @@ TEST_CASE(where_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
{
// input data filled with values 1 to nelements
std::vector<float> x_data(s.elements());
std::iota(x_data.begin(), x_data.end(), 1);
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, x_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
return result_vector;
}
TEST_CASE(trilu_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
std::vector<float> gold = {1, 2, 3, 4, 0, 6, 7, 8, 0, 0, 11, 12};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(trilu_batch_diff_k_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_batch_diff_k_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p);
std::vector<float> gold = {0, 0, 3, 0, 0, 0, 0, 0, 9, 0, 0, 0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(trilu_lower_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_lower_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
std::vector<float> gold = {0, 0, 0, 0, 5, 0, 0, 0, 9, 10, 0, 0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(trilu_out_k_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_out_k_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
std::vector<float> gold(12, 0);
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(trilu_row_one_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_row_one_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p);
std::vector<float> gold = {0, 2, 3, 4};
EXPECT(migraphx::verify_range(result_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -81,6 +81,64 @@ void throws_shape(const migraphx::shape&, Ts...) ...@@ -81,6 +81,64 @@ void throws_shape(const migraphx::shape&, Ts...)
"An expected shape should not be passed to throws_shape function"); "An expected shape should not be passed to throws_shape function");
} }
TEST_CASE(argmax_axis0)
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmax", {{"axis", 0}}),
input);
}
TEST_CASE(argmax_axis1)
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
TEST_CASE(argmax_axis2)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
TEST_CASE(argmax_axis_neg)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmax", {{"axis", -1}}),
input);
}
TEST_CASE(argmax_axis_outofbounds)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
}
TEST_CASE(argmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::int64_type, {{1, 4, 0}, {1, 1, 0}, {4, 4, 0}, {5, 5, 0}}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
TEST_CASE(argmax_dyn1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::int64_type, {{1, 4, 0}, {3, 3, 0}, {1, 1, 0}, {4, 6, 0}}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
TEST_CASE(binary_dyn_static_error) TEST_CASE(binary_dyn_static_error)
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}};
...@@ -409,6 +467,199 @@ TEST_CASE(deconvolution_shape) ...@@ -409,6 +467,199 @@ TEST_CASE(deconvolution_shape)
weights_3d); weights_3d);
} }
TEST_CASE(dot_ndim_error0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error4)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {10, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_outer_error)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 4}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_3D_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_3D_test_1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_3D_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_4D_test)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_static_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {8, 8, 0}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_static_mismatch_error)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_dyn_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5, 0}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_dyn_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {4, 5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, 5}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_mismatch_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_mismatch_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4, 0}, {5, 5, 0}, {2, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(flatten_shape) TEST_CASE(flatten_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}}; migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
...@@ -437,6 +688,62 @@ TEST_CASE(flatten_shape) ...@@ -437,6 +688,62 @@ TEST_CASE(flatten_shape)
throws_shape(migraphx::make_op("flatten", {{"axis", -5}}), input); throws_shape(migraphx::make_op("flatten", {{"axis", -5}}), input);
} }
TEST_CASE(flatten_dyn_axis0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {192, 768, 0}}},
migraphx::make_op("flatten", {{"axis", 0}}),
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {192, 768, 0}}},
migraphx::make_op("flatten", {{"axis", -4}}),
input);
}
TEST_CASE(flatten_dyn_axis1)
{
migraphx::shape input{migraphx::shape::float_type,
{{2, 2, 2}, {4, 4, 0}, {4, 6, 5}, {4, 6, 5}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2, 2, 2}, {4 * 4 * 4, 4 * 6 * 6, 0}}},
migraphx::make_op("flatten", {{"axis", 1}}),
input);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2, 2, 2}, {4 * 4 * 4, 4 * 6 * 6, 0}}},
migraphx::make_op("flatten", {{"axis", -3}}),
input);
}
TEST_CASE(flatten_dyn_axis2)
{
migraphx::shape input{migraphx::shape::float_type,
{{2, 2, 2}, {4, 4, 0}, {4, 6, 5}, {4, 6, 5}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2 * 4, 2 * 4, 0}, {4 * 4, 6 * 6, 5 * 5}}},
migraphx::make_op("flatten", {{"axis", 2}}),
input);
}
TEST_CASE(flatten_dyn_axis3)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1 * 4 * 6, 4 * 4 * 6, 0}, {8, 8, 0}}},
migraphx::make_op("flatten", {{"axis", 3}}),
input);
}
TEST_CASE(flatten_dyn_axis4)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{1 * 4 * 6 * 8, 4 * 4 * 6 * 8, 0}, {1, 1, 0}}},
migraphx::make_op("flatten", {{"axis", 4}}),
input);
}
TEST_CASE(gather) TEST_CASE(gather)
{ {
{ {
...@@ -524,46 +831,6 @@ TEST_CASE(gather) ...@@ -524,46 +831,6 @@ TEST_CASE(gather)
} }
} }
// 3 input arguments
TEST_CASE(gemm)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {10, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
}
TEST_CASE(get_tuple_elem_test) TEST_CASE(get_tuple_elem_test)
{ {
migraphx::shape s0{migraphx::shape::bool_type, {1, 1}}; migraphx::shape s0{migraphx::shape::bool_type, {1, 1}};
...@@ -1017,106 +1284,6 @@ TEST_CASE(lstm) ...@@ -1017,106 +1284,6 @@ TEST_CASE(lstm)
} }
} }
// 2 inputs arguments
TEST_CASE(matmul)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
}
TEST_CASE(multibroadcast) TEST_CASE(multibroadcast)
{ {
{ {
...@@ -1549,16 +1716,108 @@ TEST_CASE(nms_shape) ...@@ -1549,16 +1716,108 @@ TEST_CASE(nms_shape)
score_thres_s); score_thres_s);
} }
TEST_CASE(pooling_shape) TEST_CASE(pad_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {2, 3, 5, 5}};
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), input);
}
TEST_CASE(pad_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {2, 3, 6, 6}};
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 2, 2, 0, 0, 1, 1}}}), input);
}
TEST_CASE(pad_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {3, 5, 0}, {3, 5, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {5, 7, 0}, {5, 7, 0}}};
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), input);
}
TEST_CASE(pad_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {3, 5, 5}, {3, 5, 5}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {5, 7, 7}, {5, 7, 7}}};
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), input);
}
TEST_CASE(pooling_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
{"stride", {0}},
{"lengths", {1}}}),
input);
}
TEST_CASE(pooling_shape1)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}}; migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
input);
}
TEST_CASE(pooling_shape2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"ceil_mode", true}}),
input);
}
TEST_CASE(pooling_shape3)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 3, 3}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {2, 2}},
{"stride", {3, 3}},
{"lengths", {3, 3}},
{"ceil_mode", true}}),
input);
}
TEST_CASE(pooling_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {3, 3, 3}, {3, 3, 0}}};
throws_shape(migraphx::make_op("pooling", throws_shape(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}}, {"padding", {1}},
{"stride", {0}}, {"stride", {0}},
{"lengths", {1}}}), {"lengths", {1}}}),
input); input);
}
TEST_CASE(pooling_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {3, 3, 3}, {3, 3, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {1, 1, 1}, {1, 1, 0}}};
expect_shape(output, expect_shape(output,
migraphx::make_op("pooling", migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {{"mode", migraphx::op::pooling_mode::max},
...@@ -1566,9 +1825,15 @@ TEST_CASE(pooling_shape) ...@@ -1566,9 +1825,15 @@ TEST_CASE(pooling_shape)
{"stride", {3, 3}}, {"stride", {3, 3}},
{"lengths", {1, 1}}}), {"lengths", {1, 1}}}),
input); input);
}
migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}}; TEST_CASE(pooling_dyn_shape2)
expect_shape(output1, {
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {5, 5, 0}, {3, 3, 3}, {3, 3, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 0}, {5, 5, 0}, {2, 2, 2}, {2, 2, 0}}};
expect_shape(output,
migraphx::make_op("pooling", migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}}, {"padding", {0, 0}},
...@@ -1578,6 +1843,37 @@ TEST_CASE(pooling_shape) ...@@ -1578,6 +1843,37 @@ TEST_CASE(pooling_shape)
input); input);
} }
TEST_CASE(pooling_dyn_shape3)
{
migraphx::shape input{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {4, 12, 8}, {4, 12, 8}}};
migraphx::shape output{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {2, 4, 3}, {2, 4, 3}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
input);
}
TEST_CASE(pooling_dyn_shape4)
{
migraphx::shape input{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {4, 12, 8}, {4, 12, 8}}};
migraphx::shape output{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {3, 6, 4}, {3, 6, 4}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {2, 2}},
{"stride", {3, 3}},
{"lengths", {3, 3}},
{"ceil_mode", true}}),
input);
}
TEST_CASE(prefix_scan_sum) TEST_CASE(prefix_scan_sum)
{ {
{ {
...@@ -1682,9 +1978,51 @@ void test_reduce_ops() ...@@ -1682,9 +1978,51 @@ void test_reduce_ops()
} }
} }
// dynamic shape
template <class T>
void test_dyn_reduce_ops()
{
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{2, 3, 3}, {1, 1, 0}})},
T{{-1}},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{1, 1, 0}, {2, 4, 4}})},
T{{0}},
input);
}
{
// Empty axis argument reduces all axes
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{1, 1, 0}, {1, 1, 0}})},
T{{}},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
throws_shape(T{{4}}, input);
}
}
TEST_CASE(reduce_max) { test_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); } TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod) { test_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); } TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reshape_shape) TEST_CASE(reshape_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
...@@ -1720,6 +2058,55 @@ TEST_CASE(reshape_shape) ...@@ -1720,6 +2058,55 @@ TEST_CASE(reshape_shape)
} }
} }
TEST_CASE(reshape_dyn_shape)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
for(auto&& new_shape : std::vector<std::vector<int64_t>>{
{-1, 1, 1, 24}, {0, 8, 3, 1}, {-1, 3, 4, 2}, {0, 2, 4, 3}})
{
std::vector<migraphx::shape::dynamic_dimension> out_dyn_dims{};
for(std::size_t i = 0; i < new_shape.size(); ++i)
{
if(new_shape[i] == 0 or new_shape[i] == -1)
{
out_dyn_dims.push_back(input.dyn_dims().at(i));
}
else
{
std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d, 0});
}
}
migraphx::shape output{migraphx::shape::float_type, out_dyn_dims};
expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
}
TEST_CASE(reshape_multiple_non_fixed_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 20, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {0, 1, 0, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_fixed_ele_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 10, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {0, 1, 5, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_non_fixed_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {2, 1, 1, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(rnn) TEST_CASE(rnn)
{ {
{ {
...@@ -1920,6 +2307,20 @@ TEST_CASE(slice_shape) ...@@ -1920,6 +2307,20 @@ TEST_CASE(slice_shape)
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); } TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(softmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}};
expect_shape(input, migraphx::make_op("softmax", {{"axis", 0}}), input);
}
TEST_CASE(softmax_dyn1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {5, 8, 6}}};
expect_shape(input, migraphx::make_op("softmax", {{"axis", 0}}), input);
}
TEST_CASE(test_argmax) TEST_CASE(test_argmax)
{ {
{ {
...@@ -2042,6 +2443,30 @@ TEST_CASE(test_squeeze_all) ...@@ -2042,6 +2443,30 @@ TEST_CASE(test_squeeze_all)
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1); expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
} }
TEST_CASE(test_squeeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s3, migraphx::make_op("squeeze"), s1);
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_squeeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s3, migraphx::make_op("squeeze", {{"axes", {-2, -4}}}), s1);
}
TEST_CASE(test_squeeze_transpose) TEST_CASE(test_squeeze_transpose)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 1}, {4, 1, 4}}; migraphx::shape s1{migraphx::shape::float_type, {4, 4, 1}, {4, 1, 4}};
...@@ -2083,6 +2508,30 @@ TEST_CASE(test_unsqueeze) ...@@ -2083,6 +2508,30 @@ TEST_CASE(test_unsqueeze)
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), s1);
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {-1, -3}}}), s1);
}
TEST_CASE(test_unsqueeze_step) TEST_CASE(test_unsqueeze_step)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
...@@ -2114,13 +2563,27 @@ TEST_CASE(test_unsqueeze_mismatch_step_axis) ...@@ -2114,13 +2563,27 @@ TEST_CASE(test_unsqueeze_mismatch_step_axis)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1); throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1);
} }
TEST_CASE(test_unsqueeze_negative_axis) TEST_CASE(test_unsqueeze_negative_axis1)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
} }
TEST_CASE(test_unsqueeze_negative_axis2)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 3, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis3)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 5, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-3}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar) TEST_CASE(test_unsqueeze_scalar)
{ {
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}}; migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
...@@ -2226,6 +2689,28 @@ TEST_CASE(transpose_shape) ...@@ -2226,6 +2689,28 @@ TEST_CASE(transpose_shape)
throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input); throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input);
} }
TEST_CASE(transpose_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}}};
migraphx::shape output{migraphx::shape::float_type, {{2, 2, 0}, {1, 4, 0}}};
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input);
}
TEST_CASE(transpose_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4, 0}, {4, 4, 0}, {1, 4, 0}}};
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), input);
}
TEST_CASE(transpose_axes_error)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
throws_shape(migraphx::make_op("transpose", {{"permutation", {1}}}), input);
}
TEST_CASE(step_test) TEST_CASE(step_test)
{ {
migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4}}; migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4}};
......
...@@ -94,6 +94,16 @@ def disabled_tests_onnx_1_8_1(backend_test): ...@@ -94,6 +94,16 @@ def disabled_tests_onnx_1_8_1(backend_test):
backend_test.exclude(r'test_unsqueeze_unsorted_axes_cpu') backend_test.exclude(r'test_unsqueeze_unsorted_axes_cpu')
def disabled_tests_onnx_1_10_0(backend_test):
# unsupported shape attributes
backend_test.exclude(r'test_shape_end_1_cpu')
backend_test.exclude(r'test_shape_end_negative_1_cpu')
backend_test.exclude(r'test_shape_start_1_cpu')
backend_test.exclude(r'test_shape_start_1_end_2_cpu')
backend_test.exclude(r'test_shape_start_1_end_negative_1_cpu')
backend_test.exclude(r'test_shape_start_negative_1_cpu')
def create_backend_test(testname=None, target_device=None): def create_backend_test(testname=None, target_device=None):
if target_device is not None: if target_device is not None:
c2.set_device(target_device) c2.set_device(target_device)
...@@ -314,6 +324,9 @@ def create_backend_test(testname=None, target_device=None): ...@@ -314,6 +324,9 @@ def create_backend_test(testname=None, target_device=None):
if version.parse(onnx.__version__) >= version.parse("1.8.0"): if version.parse(onnx.__version__) >= version.parse("1.8.0"):
disabled_tests_onnx_1_8_1(backend_test) disabled_tests_onnx_1_8_1(backend_test)
if version.parse(onnx.__version__) >= version.parse("1.10.0"):
disabled_tests_onnx_1_10_0(backend_test)
# import all test cases at global scope to make # import all test cases at global scope to make
# them visible to python.unittest. # them visible to python.unittest.
......
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
template <class T> template <class T>
void matmul_test() void dot_2d_test()
{ {
migraphx::program p; migraphx::program p;
...@@ -82,11 +82,11 @@ void matmul_test() ...@@ -82,11 +82,11 @@ void matmul_test()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(matmul_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(matmul_test<double>) TEST_CASE_REGISTER(dot_2d_test<double>)
template <class T> template <class T>
void matmul_test_ex() void dot_4d_test()
{ {
migraphx::program p; migraphx::program p;
...@@ -133,10 +133,10 @@ void matmul_test_ex() ...@@ -133,10 +133,10 @@ void matmul_test_ex()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(matmul_test_ex<float>) TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(matmul_test_ex<double>) TEST_CASE_REGISTER(dot_4d_test<double>)
TEST_CASE(matmul_mutli_dim_2) TEST_CASE(dot_3D_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2) ...@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_dim_2_beta0) TEST_CASE(dot_3D_C_test0)
{ {
migraphx::program p; migraphx::program p;
...@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0) ...@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_beta_0) TEST_CASE(dot_3D_C_test1)
{ {
migraphx::program p; migraphx::program p;
...@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0) ...@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(matmul_mutli_dim_2_3) TEST_CASE(dot_4D_test1)
{ {
migraphx::program p; migraphx::program p;
...@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3) ...@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_dim1_2_3) TEST_CASE(dot_4D_alpha_beta_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_3args) TEST_CASE(dot_4D_alpha_beta_C_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args) ...@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_3args) TEST_CASE(dot_2D_C_test0)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args) ...@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args)
} }
} }
TEST_CASE(matmul_vv_inner_product) TEST_CASE(dot_vv_inner_product)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product)
} }
} }
TEST_CASE(matmul_vm) TEST_CASE(dot_vm)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm) ...@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm)
} }
} }
TEST_CASE(matmul_mv) TEST_CASE(dot_mv)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv) ...@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv)
} }
} }
TEST_CASE(matmul_mm1) TEST_CASE(dot_mm1)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1) ...@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1)
} }
} }
TEST_CASE(matmul_mm2) TEST_CASE(dot_mm2)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2) ...@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2)
} }
} }
TEST_CASE(dot_dyn_2D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), ap, bp);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {4, 5}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape, a.data());
params["b"] = migraphx::argument(b_shape, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(dot_dyn_4D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 1, 5, 3}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape0, a.data());
params["b"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(quant_dot_2args_multi4) TEST_CASE(quant_dot_2args_multi4)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment