Commit 84725d72 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_batch_pass

parents 7f1e8443 bfd77388
c9a53c925510a101f5ca94d5ecda0924e40a8463
This diff is collapsed.
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -840,6 +840,25 @@ TEST_CASE(concat_test) ...@@ -840,6 +840,25 @@ TEST_CASE(concat_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(concat_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {1, 4, 0}, {3, 3, 0}}});
auto l1 = mm->add_parameter(
"1", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {1, 4, 0}, {3, 3, 0}}});
auto ret = mm->add_instruction(migraphx::make_op("concat"), l0, l1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("concat_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(constant_test) TEST_CASE(constant_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2048,6 +2067,46 @@ TEST_CASE(gather_test) ...@@ -2048,6 +2067,46 @@ TEST_CASE(gather_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gather_scalar_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
std::vector<size_t> idims{1};
auto l1 =
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, idims, {0}});
int axis = 1;
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
auto prog = optimize_onnx("gather_scalar_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gather_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"data",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}});
auto l1 = mm->add_parameter(
"indices",
migraphx::shape{migraphx::shape::int32_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
auto cont_l0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto cont_l1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
int axis = 1;
auto gather_op = migraphx::make_op("gather", {{"axis", axis}});
auto ret = mm->add_instruction(gather_op, cont_l0, cont_l1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("gather_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gather_elements_axis0_test) TEST_CASE(gather_elements_axis0_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2118,6 +2177,24 @@ TEST_CASE(gathernd_test) ...@@ -2118,6 +2177,24 @@ TEST_CASE(gathernd_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gathernd_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data",
migraphx::shape{migraphx::shape::float_type, {{2, 4, 2}, {2, 4}}});
auto l1 = mm->add_parameter("indices",
migraphx::shape{migraphx::shape::int64_type, {{1, 3}, {2, 2}}});
auto r = mm->add_instruction(migraphx::make_op("gathernd"), l0, l1);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["data"] = {{2, 4, 2}, {2, 4}};
options.map_dyn_input_dims["indices"] = {{1, 3}, {2, 2}};
auto prog = migraphx::parse_onnx("gathernd_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gathernd_batch_dims_test) TEST_CASE(gathernd_batch_dims_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2632,14 +2709,16 @@ TEST_CASE(if_else_test) ...@@ -2632,14 +2709,16 @@ TEST_CASE(if_else_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}}; migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f); std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones); std::vector<float> rand = {1.3865, -0.494756, -0.283504, 0.200491, -0.490031, 1.32388};
std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589};
auto l2 = mm->add_literal(s, rand); auto l1 = mm->add_literal(s, ones);
auto x = mm->add_parameter("x", s); auto l2 = mm->add_literal(s, rand);
auto y = mm->add_parameter("y", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto cond = mm->add_parameter("cond", sc);
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
...@@ -2653,15 +2732,32 @@ TEST_CASE(if_else_test) ...@@ -2653,15 +2732,32 @@ TEST_CASE(if_else_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r}); mm->add_return({r});
std::ifstream ifs("if_else_test.onnx", std::ios::binary); auto prog = migraphx::parse_onnx("if_else_test.onnx");
ifs.seekg(0, std::ios::end); EXPECT(p == prog);
auto length = ifs.tellg(); }
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length); TEST_CASE(if_else_test_inlined)
ifs.read(onnx_buffer.data(), length); {
ifs.close(); migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
mm->add_literal(s, ones);
std::vector<float> rand = {0.811412, -0.949771, -0.169276, 0.36552, -0.14801, 2.07061};
auto l2 = mm->add_literal(s, rand);
mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto re = mm->add_instruction(migraphx::make_op("mul"), y, l2);
mm->add_return({re});
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {}); auto prog = migraphx::parse_onnx("if_else_test_inlined.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -2734,6 +2830,70 @@ TEST_CASE(if_param_test) ...@@ -2734,6 +2830,70 @@ TEST_CASE(if_param_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape s_trail{migraphx::shape::float_type, {2, 3, 1}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s_trail, ones);
std::vector<float> rand = {-1.01837, -0.305541, -0.254105, 0.892955, 1.38714, -0.584205};
mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s_trail);
mm->add_parameter("y", s);
auto rt = mm->add_instruction(migraphx::make_op("add"), x, l1);
auto rt2 = mm->add_instruction(migraphx::make_op("add"), x, x);
mm->add_return({rt, rt2});
auto prog = migraphx::parse_onnx("if_then_else_multi_output_shapes_inlined_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_then_else_multi_output_shapes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
migraphx::shape s{migraphx::shape::float_type, {2, 3, 1}};
migraphx::shape s_trail{migraphx::shape::float_type, {2, 3, 1}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s_trail, ones);
std::vector<float> rand = {-0.753997, 0.707831, -0.865795, 2.49574, 0.464937, -0.168745};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s_trail);
auto y = mm->add_parameter("y", s);
auto cond = mm->add_parameter("cond", sc);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
auto rt2 = then_mod->add_instruction(migraphx::make_op("add"), x, x);
then_mod->add_return({rt, rt2});
auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
auto re2 = else_mod->add_instruction(migraphx::make_op("sub"), y, l2);
else_mod->add_return({re, re2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r1, r2});
auto prog = migraphx::parse_onnx("if_then_else_multi_output_shapes_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2774,14 +2934,16 @@ TEST_CASE(if_then_test) ...@@ -2774,14 +2934,16 @@ TEST_CASE(if_then_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}}; migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f); std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones); std::vector<float> rand = {-0.266913, -0.180328, -0.124268, -1.23768, 0.312334, 1.18475};
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand); auto l1 = mm->add_literal(s, ones);
auto x = mm->add_parameter("x", s); auto l2 = mm->add_literal(s, rand);
auto y = mm->add_parameter("y", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto cond = mm->add_parameter("cond", sc);
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
...@@ -2799,6 +2961,32 @@ TEST_CASE(if_then_test) ...@@ -2799,6 +2961,32 @@ TEST_CASE(if_then_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(if_then_test_inlined)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
mm->add_parameter("y", s);
auto rt = mm->add_instruction(migraphx::make_op("add"), x, l1);
mm->add_return({rt});
auto prog = migraphx::parse_onnx("if_then_test_inlined.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_tuple_test) TEST_CASE(if_tuple_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -5599,53 +5787,67 @@ TEST_CASE(scatter_none_test) ...@@ -5599,53 +5787,67 @@ TEST_CASE(scatter_none_test)
TEST_CASE(scatternd_test) TEST_CASE(scatternd_test)
{ {
{ migraphx::program p;
migraphx::program p; auto* mm = p.get_main_module();
auto* mm = p.get_main_module(); auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l0 = auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto l1 = auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); mm->add_return({r});
auto l2 = auto prog = migraphx::parse_onnx("scatternd_test.onnx");
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
{ TEST_CASE(scatternd_dyn_test)
migraphx::program p; {
auto* mm = p.get_main_module(); // dynamic input.
auto l0 = migraphx::program p;
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); auto* mm = p.get_main_module();
auto l1 = // parameters with dynamic dimensions
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); auto l0 = mm->add_parameter(
auto l2 = "data", migraphx::shape{migraphx::shape::float_type, {{1, 3, 2}, {2, 2}, {2, 2}}});
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); auto l1 = mm->add_parameter(
auto r = mm->add_instruction(migraphx::make_op("scatternd_add"), l0, l1, l2); "indices", migraphx::shape{migraphx::shape::int64_type, {{2, 1, 2}, {1, 1}, {2, 2}}});
mm->add_return({r}); auto l2 = mm->add_parameter(
auto prog = migraphx::parse_onnx("scatternd_add_test.onnx"); "updates", migraphx::shape{migraphx::shape::float_type, {{2, 1, 2}, {1, 1}, {2, 2}}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["data"] = {{1, 3, 2}, {2, 2}, {2, 2}};
options.map_dyn_input_dims["indices"] = {{2, 1, 2}, {1, 1}, {2, 2}};
options.map_dyn_input_dims["updates"] = {{2, 1, 2}, {1, 1}, {2, 2}};
auto prog = migraphx::parse_onnx("scatternd_dyn_test.onnx", options);
EXPECT(p == prog); EXPECT(p == prog);
} }
{ TEST_CASE(scatternd_add_test)
migraphx::program p; {
auto* mm = p.get_main_module(); migraphx::program p;
auto l0 = auto* mm = p.get_main_module();
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 = auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto l2 = auto r = mm->add_instruction(migraphx::make_op("scatternd_add"), l0, l1, l2);
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); mm->add_return({r});
auto r = mm->add_instruction(migraphx::make_op("scatternd_mul"), l0, l1, l2); auto prog = migraphx::parse_onnx("scatternd_add_test.onnx");
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_mul_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(scatternd_mul_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_mul"), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_mul_test.onnx");
EXPECT(p == prog);
} }
TEST_CASE(selu_test) TEST_CASE(selu_test)
...@@ -5808,6 +6010,44 @@ TEST_CASE(slice_test) ...@@ -5808,6 +6010,44 @@ TEST_CASE(slice_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(slice_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {1, 3, 0}, {2, 2, 0}}});
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), l0);
mm->add_return({ret});
migraphx::onnx_options options;
// Parser converts the dynamic input shape to static unless there is at least one non-fixed
// dynamic dimension. Slicing is not allowed along the non-fixed axis 1.
options.map_dyn_input_dims["0"] = {{3, 3, 0}, {1, 3, 0}, {2, 2, 0}};
auto prog = migraphx::parse_onnx("slice_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(slice_step_dyn_test)
{
// A slice command with non-default steps will have a "Step" instruction added in parsing.
// At the time of writing, Step doesn't support dynamic shape input.
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_step_dyn_test.onnx", options); }));
}
TEST_CASE(slice_reverse_dyn_test)
{
// A slice command with negative step on any axis will have a "Reverse" instruction added in
// parsing. At the time of writing, Reverse doesn't support dynamic shape input.
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_reverse_dyn_test.onnx", options); }));
}
TEST_CASE(slice_3arg_test) TEST_CASE(slice_3arg_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -6765,4 +7005,35 @@ TEST_CASE(where_test) ...@@ -6765,4 +7005,35 @@ TEST_CASE(where_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(where_dyn_test)
{
// TODO: broadcasting for dynamic shapes isn't implemented at time of writing.
// Update this test case to use shapes that require broadcasting, when available.
migraphx::program p;
auto* mm = p.get_main_module();
auto lc = mm->add_parameter(
"c", migraphx::shape{migraphx::shape::bool_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
auto lx = mm->add_parameter(
"x", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
auto ly = mm->add_parameter(
"y", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
auto r = mm->add_instruction(migraphx::make_op("where"), lc, lx, ly);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("where_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(where_mixed_test)
{
// mixture of static and dynamic input shapes is not supported
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("where_mixed_test.onnx", options); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -590,6 +590,28 @@ TEST_CASE(if_else_test) ...@@ -590,6 +590,28 @@ TEST_CASE(if_else_test)
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape bool_data{migraphx::shape::bool_type, {1}};
bool b_data = false;
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
pp["cond"] = migraphx::argument(bool_data, &b_data);
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0866565, -0.371067, 0.017719, 0.0250614, 0.0612539, -0.744683};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_else_test_inlined)
{
migraphx::program p = migraphx::parse_onnx("if_else_test_inlined.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data()); pp["x"] = migraphx::argument(s_data, data.data());
...@@ -599,8 +621,49 @@ TEST_CASE(if_else_test) ...@@ -599,8 +621,49 @@ TEST_CASE(if_else_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = { std::vector<float> gold = {0.0507132, -0.712328, 0.0105797, 0.04569, 0.0185013, -1.16472};
-0.0364609435, 0.475317657, -0.00417715637, -0.0599277429, 0.0755792186, -0.0218581557}; EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_then_test)
{
migraphx::program p = migraphx::parse_onnx("if_then_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape bool_data{migraphx::shape::bool_type, {1}};
bool b_data = true;
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
pp["cond"] = migraphx::argument(bool_data, &b_data);
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_then_test_inlined)
{
migraphx::program p = migraphx::parse_onnx("if_then_test_inlined.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -637,6 +700,67 @@ TEST_CASE(if_literal_test) ...@@ -637,6 +700,67 @@ TEST_CASE(if_literal_test)
} }
} }
TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
{
migraphx::program p =
migraphx::parse_onnx("if_then_else_multi_output_shapes_inlined_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape x_data{migraphx::shape::float_type, {2, 3, 1}};
migraphx::shape y_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(x_data, data.data());
pp["y"] = migraphx::argument(y_data, data.data());
auto result_args = p.eval(pp);
auto result = result_args.front();
auto result_b = result_args.back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> result_vector_back;
result_b.visit([&](auto output) { result_vector_back.assign(output.begin(), output.end()); });
result_vector.insert(result_vector.end(), result_vector_back.begin(), result_vector_back.end());
std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_then_else_multi_output_shapes_test)
{
migraphx::program p = migraphx::parse_onnx("if_then_else_multi_output_shapes_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3, 1}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape bool_data{migraphx::shape::bool_type, {1}};
bool b_data = true;
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
pp["cond"] = migraphx::argument(bool_data, &b_data);
auto result_args = p.eval(pp);
auto result = result_args.front();
auto result_b = result_args.back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> result_vector_back;
result_b.visit([&](auto output) { result_vector_back.assign(output.begin(), output.end()); });
result_vector.insert(result_vector.end(), result_vector_back.begin(), result_vector_back.end());
std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
{ {
auto run_prog = [](bool cond) { auto run_prog = [](bool cond) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment