Commit fe493c28 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-gsg

parents ba0b3794 cce35871
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <test.hpp>
void run_promote(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::promote_literals{}, migraphx::dead_code_elimination{}});
}
void run_promote_and_ecs(migraphx::program& p)
{
migraphx::run_passes(p,
{migraphx::promote_literals{},
migraphx::dead_code_elimination{},
migraphx::eliminate_common_subexpression{},
migraphx::dead_code_elimination{}});
}
TEST_CASE(promote_only)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
run_promote(p0);
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins3 = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto literal_ins2 = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto literal_ins1 = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto literal_ins0 = mm1->add_literal(migraphx::literal{lit_s, {6}});
// create batch submodules
auto create_submodule = [&](std::size_t batch_size,
migraphx::instruction_ref lit,
const std::string& module_name) {
auto* submod = p1.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), lit, sm_input);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, literal_ins0, "dim_1");
auto* dim2 = create_submodule(2, literal_ins1, "dim_2");
auto* dim3 = create_submodule(3, literal_ins2, "dim_3");
auto* dim4 = create_submodule(4, literal_ins3, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm1->insert_parameter(std::next(literal_ins3), "data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm1->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm1->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm1->add_return({ret});
}
EXPECT(p0 == p1);
}
TEST_CASE(promote_and_ecs0)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
run_promote_and_ecs(p0);
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p1.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm1->insert_parameter(std::next(literal_ins), "data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm1->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm1->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm1->add_return({ret});
}
EXPECT(p0 == p1);
}
TEST_CASE(promote_and_ecs1)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins0 = submod->add_literal(migraphx::literal{lit_s, {6}});
auto literal_ins1 = submod->add_literal(migraphx::literal{lit_s, {2}});
auto broadcast_lit0 = submod->add_instruction(
migraphx::make_op("multibroadcast"), literal_ins0, sm_input);
auto broadcast_lit1 = submod->add_instruction(
migraphx::make_op("multibroadcast"), literal_ins1, sm_input);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit0);
auto mul_ins =
submod->add_instruction(migraphx::make_op("mul"), add_ins, broadcast_lit1);
submod->add_return({mul_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
run_promote_and_ecs(p0);
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins1 = mm1->add_literal(migraphx::literal{lit_s, {2}});
auto literal_ins0 = mm1->add_literal(migraphx::literal{lit_s, {6}});
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p1.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto broadcast_lit0 = submod->add_instruction(
migraphx::make_op("multibroadcast"), literal_ins0, sm_input);
auto broadcast_lit1 = submod->add_instruction(
migraphx::make_op("multibroadcast"), literal_ins1, sm_input);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit0);
auto mul_ins =
submod->add_instruction(migraphx::make_op("mul"), add_ins, broadcast_lit1);
submod->add_return({mul_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm1->insert_parameter(std::next(literal_ins1), "data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm1->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2});
auto ret =
mm1->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm1->add_return({ret});
}
EXPECT(p0 == p1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -105,6 +105,10 @@ def disabled_tests_onnx_1_10_0(backend_test): ...@@ -105,6 +105,10 @@ def disabled_tests_onnx_1_10_0(backend_test):
backend_test.exclude(r'test_shape_start_negative_1_cpu') backend_test.exclude(r'test_shape_start_negative_1_cpu')
def disabled_tests_onnx_1_12_0(backend_test):
backend_test.exclude(r'test_scatter_elements_with_duplicate_indices_cpu')
def create_backend_test(testname=None, target_device=None): def create_backend_test(testname=None, target_device=None):
if target_device is not None: if target_device is not None:
c2.set_device(target_device) c2.set_device(target_device)
...@@ -328,6 +332,9 @@ def create_backend_test(testname=None, target_device=None): ...@@ -328,6 +332,9 @@ def create_backend_test(testname=None, target_device=None):
if version.parse(onnx.__version__) >= version.parse("1.10.0"): if version.parse(onnx.__version__) >= version.parse("1.10.0"):
disabled_tests_onnx_1_10_0(backend_test) disabled_tests_onnx_1_10_0(backend_test)
if version.parse(onnx.__version__) >= version.parse("1.12.0"):
disabled_tests_onnx_1_12_0(backend_test)
# import all test cases at global scope to make # import all test cases at global scope to make
# them visible to python.unittest. # them visible to python.unittest.
......
...@@ -1197,7 +1197,7 @@ TEST_CASE(dot_dyn_2D_test) ...@@ -1197,7 +1197,7 @@ TEST_CASE(dot_dyn_2D_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}}; migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
auto ap = mm->add_parameter("a", a_shape); auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape); auto bp = mm->add_parameter("b", b_shape);
...@@ -1250,8 +1250,7 @@ TEST_CASE(dot_dyn_4D_test) ...@@ -1250,8 +1250,7 @@ TEST_CASE(dot_dyn_4D_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, migraphx::shape a_shape{migraphx::shape::float_type, {{1, 1}, {1, 1}, {4, 6, {4}}, {5, 5}}};
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
auto al = mm->add_parameter("a", a_shape); auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape); auto bl = mm->add_parameter("b", b_shape);
......
...@@ -64,7 +64,7 @@ TEST_CASE(abs_dyn_test) ...@@ -64,7 +64,7 @@ TEST_CASE(abs_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 8, 0}, {2, 2, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 8}, {2, 2}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("abs"), input); mm->add_instruction(migraphx::make_op("abs"), input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -102,7 +102,7 @@ TEST_CASE(acos_dyn_test) ...@@ -102,7 +102,7 @@ TEST_CASE(acos_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("acos"), input); mm->add_instruction(migraphx::make_op("acos"), input);
...@@ -143,7 +143,7 @@ TEST_CASE(acosh_dyn_test) ...@@ -143,7 +143,7 @@ TEST_CASE(acosh_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{1.1f, 1.2f, 2.0f}; std::vector<float> input_data{1.1f, 1.2f, 2.0f};
...@@ -230,7 +230,7 @@ TEST_CASE(add_dyn_test) ...@@ -230,7 +230,7 @@ TEST_CASE(add_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -330,7 +330,7 @@ TEST_CASE(argmax_dyn_test) ...@@ -330,7 +330,7 @@ TEST_CASE(argmax_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {3, 6, 0}, {3, 6, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {3, 6}, {3, 6}}};
auto dl = mm->add_parameter("X", s); auto dl = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 0}}), dl); mm->add_instruction(migraphx::make_op("argmax", {{"axis", 0}}), dl);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -446,7 +446,7 @@ TEST_CASE(asin_dyn_test) ...@@ -446,7 +446,7 @@ TEST_CASE(asin_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("asin"), input); mm->add_instruction(migraphx::make_op("asin"), input);
...@@ -487,7 +487,7 @@ TEST_CASE(asinh_dyn_test) ...@@ -487,7 +487,7 @@ TEST_CASE(asinh_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("asinh"), input); mm->add_instruction(migraphx::make_op("asinh"), input);
...@@ -528,7 +528,7 @@ TEST_CASE(atan_dyn_test) ...@@ -528,7 +528,7 @@ TEST_CASE(atan_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("atan"), input); mm->add_instruction(migraphx::make_op("atan"), input);
...@@ -569,7 +569,7 @@ TEST_CASE(atanh_dyn_test) ...@@ -569,7 +569,7 @@ TEST_CASE(atanh_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("atanh"), input); mm->add_instruction(migraphx::make_op("atanh"), input);
...@@ -615,7 +615,7 @@ TEST_CASE(avgpool_dyn_test) ...@@ -615,7 +615,7 @@ TEST_CASE(avgpool_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}}}; auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s); auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average}, {{"mode", migraphx::op::pooling_mode::average},
...@@ -767,7 +767,7 @@ TEST_CASE(broadcast_2in_dyn_test) ...@@ -767,7 +767,7 @@ TEST_CASE(broadcast_2in_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 2, 0}, {2, 4, 0}}}; migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 2}, {2, 4}}};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}}; migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3}; std::vector<int32_t> b_data{-2, -3};
uint64_t axis = 0; uint64_t axis = 0;
...@@ -810,7 +810,7 @@ TEST_CASE(ceil_dyn_test) ...@@ -810,7 +810,7 @@ TEST_CASE(ceil_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{4, 12, 0}; migraphx::shape::dynamic_dimension dd{4, 12};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("ceil"), input); mm->add_instruction(migraphx::make_op("ceil"), input);
...@@ -958,9 +958,9 @@ TEST_CASE(concat_dyn_test) ...@@ -958,9 +958,9 @@ TEST_CASE(concat_dyn_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
int axis = 0; int axis = 0;
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, 2}, {2, 3, 2}}}; migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2}}, {2, 3, {2}}}};
migraphx::shape s1{migraphx::shape::int32_type, {{3, 4, 4}, {2, 3, 2}}}; migraphx::shape s1{migraphx::shape::int32_type, {{3, 4, {4}}, {2, 3, {2}}}};
migraphx::shape s2{migraphx::shape::int32_type, {{1, 5, 3}, {2, 3, 2}}}; migraphx::shape s2{migraphx::shape::int32_type, {{1, 5, {3}}, {2, 3, {2}}}};
auto input0 = mm->add_parameter("X", s0); auto input0 = mm->add_parameter("X", s0);
auto input1 = mm->add_parameter("Y", s1); auto input1 = mm->add_parameter("Y", s1);
...@@ -1039,8 +1039,7 @@ TEST_CASE(contiguous_dyn_test) ...@@ -1039,8 +1039,7 @@ TEST_CASE(contiguous_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape dyn_shape{migraphx::shape::float_type, migraphx::shape dyn_shape{migraphx::shape::float_type, {{1, 1}, {2, 6}, {2, 2}, {2, 2}}};
{{1, 1, 0}, {2, 6, 0}, {2, 2, 0}, {2, 2, 0}}};
auto input = mm->add_parameter("X", dyn_shape); auto input = mm->add_parameter("X", dyn_shape);
mm->add_instruction(migraphx::make_op("contiguous"), input); mm->add_instruction(migraphx::make_op("contiguous"), input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -1068,7 +1067,7 @@ TEST_CASE(conv_dyn_batch_test) ...@@ -1068,7 +1067,7 @@ TEST_CASE(conv_dyn_batch_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape input_dyn_shape{migraphx::shape::float_type, migraphx::shape input_dyn_shape{migraphx::shape::float_type,
{{1, 100, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}}; {{1, 100}, {3, 3}, {4, 4}, {4, 4}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto input = mm->add_parameter("X", input_dyn_shape); auto input = mm->add_parameter("X", input_dyn_shape);
...@@ -1184,8 +1183,7 @@ TEST_CASE(conv_dyn_img_shape_test) ...@@ -1184,8 +1183,7 @@ TEST_CASE(conv_dyn_img_shape_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape input_dyn_shape{migraphx::shape::float_type, migraphx::shape input_dyn_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {4, 6}, {4, 6}}};
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}}; migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}};
auto input = mm->add_parameter("X", input_dyn_shape); auto input = mm->add_parameter("X", input_dyn_shape);
...@@ -1274,8 +1272,7 @@ TEST_CASE(conv_dyn_weights_shape_test) ...@@ -1274,8 +1272,7 @@ TEST_CASE(conv_dyn_weights_shape_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}}; migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type, migraphx::shape weights_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 3}, {2, 3}}};
{{1, 1, 0}, {3, 3, 0}, {2, 3, 0}, {2, 3, 0}}};
auto input = mm->add_parameter("X", input_shape); auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape); auto weights = mm->add_parameter("W", weights_shape);
...@@ -1350,8 +1347,7 @@ TEST_CASE(conv_dyn_img_same_upper_test) ...@@ -1350,8 +1347,7 @@ TEST_CASE(conv_dyn_img_same_upper_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape input_dyn_shape{migraphx::shape::float_type, migraphx::shape input_dyn_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {4, 6}, {4, 6}}};
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}}; migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}};
auto input = mm->add_parameter("X", input_dyn_shape); auto input = mm->add_parameter("X", input_dyn_shape);
...@@ -1422,8 +1418,7 @@ TEST_CASE(conv_dyn_kernel_same_upper_test) ...@@ -1422,8 +1418,7 @@ TEST_CASE(conv_dyn_kernel_same_upper_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}}; migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type, migraphx::shape weights_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 3}, {2, 3}}};
{{1, 1, 0}, {3, 3, 0}, {2, 3, 0}, {2, 3, 0}}};
auto input = mm->add_parameter("X", input_shape); auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape); auto weights = mm->add_parameter("W", weights_shape);
...@@ -1496,8 +1491,7 @@ TEST_CASE(conv_dyn_kernel_same_lower_test) ...@@ -1496,8 +1491,7 @@ TEST_CASE(conv_dyn_kernel_same_lower_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}}; migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type, migraphx::shape weights_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 3}, {2, 3}}};
{{1, 1, 0}, {3, 3, 0}, {2, 3, 0}, {2, 3, 0}}};
auto input = mm->add_parameter("X", input_shape); auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape); auto weights = mm->add_parameter("W", weights_shape);
...@@ -1839,7 +1833,7 @@ TEST_CASE(cos_dyn_test) ...@@ -1839,7 +1833,7 @@ TEST_CASE(cos_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("cos"), input); mm->add_instruction(migraphx::make_op("cos"), input);
...@@ -1880,7 +1874,7 @@ TEST_CASE(cosh_dyn_test) ...@@ -1880,7 +1874,7 @@ TEST_CASE(cosh_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("cosh"), input); mm->add_instruction(migraphx::make_op("cosh"), input);
...@@ -2071,7 +2065,7 @@ TEST_CASE(div_dyn_test) ...@@ -2071,7 +2065,7 @@ TEST_CASE(div_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 3}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, {3}}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -2113,7 +2107,7 @@ TEST_CASE(elu_dyn_test) ...@@ -2113,7 +2107,7 @@ TEST_CASE(elu_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
float alpha = 0.5; float alpha = 0.5;
...@@ -2184,7 +2178,7 @@ TEST_CASE(equal_dyn_test) ...@@ -2184,7 +2178,7 @@ TEST_CASE(equal_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{6, 12, 9}}; std::vector<migraphx::shape::dynamic_dimension> dd{{6, 12, {9}}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto p0 = mm->add_parameter("l", s); auto p0 = mm->add_parameter("l", s);
auto p1 = mm->add_parameter("r", s); auto p1 = mm->add_parameter("r", s);
...@@ -2231,7 +2225,7 @@ TEST_CASE(erf_dyn_test) ...@@ -2231,7 +2225,7 @@ TEST_CASE(erf_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("erf"), input); mm->add_instruction(migraphx::make_op("erf"), input);
...@@ -2272,7 +2266,7 @@ TEST_CASE(exp_dyn_test) ...@@ -2272,7 +2266,7 @@ TEST_CASE(exp_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("exp"), input); mm->add_instruction(migraphx::make_op("exp"), input);
...@@ -2313,7 +2307,7 @@ TEST_CASE(floor_dyn_test) ...@@ -2313,7 +2307,7 @@ TEST_CASE(floor_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{5, 12, 0}; migraphx::shape::dynamic_dimension dd{5, 12};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("floor"), input); mm->add_instruction(migraphx::make_op("floor"), input);
...@@ -2564,7 +2558,7 @@ TEST_CASE(gather_dyn_test0) ...@@ -2564,7 +2558,7 @@ TEST_CASE(gather_dyn_test0)
// Dynamic data, static indices // Dynamic data, static indices
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 5, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::int32_type, {{2, 5}, {3, 3}}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
std::vector<int> indices{1, 2}; std::vector<int> indices{1, 2};
...@@ -2573,7 +2567,7 @@ TEST_CASE(gather_dyn_test0) ...@@ -2573,7 +2567,7 @@ TEST_CASE(gather_dyn_test0)
auto ind = mm->add_parameter("indices", s_ind); auto ind = mm->add_parameter("indices", s_ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), x, ind); mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), x, ind);
migraphx::shape sresult{migraphx::shape::int32_type, {{2, 5, 0}, {1, 1, 0}, {2, 2, 0}}}; migraphx::shape sresult{migraphx::shape::int32_type, {{2, 5}, {1, 1}, {2, 2}}};
EXPECT(p.get_output_shapes().back() == sresult); EXPECT(p.get_output_shapes().back() == sresult);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -2599,15 +2593,15 @@ TEST_CASE(gather_dyn_test1) ...@@ -2599,15 +2593,15 @@ TEST_CASE(gather_dyn_test1)
// Dynamic data, dynamic indices // Dynamic data, dynamic indices
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 5, 0}, {4, 4, 0}}}; migraphx::shape s{migraphx::shape::int32_type, {{2, 5}, {4, 4}}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::shape s_ind{migraphx::shape::int32_type, {{1, 8, 7}, {2, 3, 3}}}; migraphx::shape s_ind{migraphx::shape::int32_type, {{1, 8, {7}}, {2, 3, {3}}}};
auto ind = mm->add_parameter("indices", s_ind); auto ind = mm->add_parameter("indices", s_ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, ind); mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, ind);
migraphx::shape sresult{migraphx::shape::int32_type, {{1, 8, 7}, {2, 3, 3}, {4, 4, 0}}}; migraphx::shape sresult{migraphx::shape::int32_type, {{1, 8, {7}}, {2, 3, {3}}, {4, 4}}};
EXPECT(p.get_output_shapes().back() == sresult); EXPECT(p.get_output_shapes().back() == sresult);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -2787,7 +2781,7 @@ TEST_CASE(gathernd_dynamic0) ...@@ -2787,7 +2781,7 @@ TEST_CASE(gathernd_dynamic0)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 2, 2}, {3, 3, 0}, {1, 1, 0}}}; migraphx::shape ds{migraphx::shape::float_type, {{2, 2, {2}}, {3, 3}, {1, 1}}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}}; migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}};
auto xdata = mm->add_parameter("X", ds); auto xdata = mm->add_parameter("X", ds);
...@@ -2824,7 +2818,7 @@ TEST_CASE(gathernd_dynamic1) ...@@ -2824,7 +2818,7 @@ TEST_CASE(gathernd_dynamic1)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 5, 2}, {1, 5, 0}, {1, 5, 0}}}; migraphx::shape ds{migraphx::shape::float_type, {{2, 5, {2}}, {1, 5}, {1, 5}}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}}; migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}};
auto xdata = mm->add_parameter("X", ds); auto xdata = mm->add_parameter("X", ds);
...@@ -2860,8 +2854,8 @@ TEST_CASE(gathernd_dynamic2) ...@@ -2860,8 +2854,8 @@ TEST_CASE(gathernd_dynamic2)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 5, 2}, {1, 5, 0}, {1, 5, 0}}}; migraphx::shape ds{migraphx::shape::float_type, {{2, 5, {2}}, {1, 5}, {1, 5}}};
migraphx::shape is{migraphx::shape::int64_type, {{2, 5, 3}, {2, 3, 3}, {1, 1}}}; migraphx::shape is{migraphx::shape::int64_type, {{2, 5, {3}}, {2, 3, {3}}, {1, 1}}};
auto xdata = mm->add_parameter("X", ds); auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is); auto xindex = mm->add_parameter("I", is);
...@@ -2897,7 +2891,7 @@ TEST_CASE(gathernd_dynamic3) ...@@ -2897,7 +2891,7 @@ TEST_CASE(gathernd_dynamic3)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1}}; migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1}};
migraphx::shape is{migraphx::shape::int64_type, {{2, 5, 3}, {2, 3, 3}, {1, 1}}}; migraphx::shape is{migraphx::shape::int64_type, {{2, 5, {3}}, {2, 3, {3}}, {1, 1}}};
auto xdata = mm->add_parameter("X", ds); auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is); auto xindex = mm->add_parameter("I", is);
...@@ -2932,8 +2926,7 @@ TEST_CASE(gathernd_dynamic4) ...@@ -2932,8 +2926,7 @@ TEST_CASE(gathernd_dynamic4)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, migraphx::shape ds{migraphx::shape::float_type, {migraphx::shape::dynamic_dimension({2, 2})}};
{migraphx::shape::dynamic_dimension({2, 2, 0})}};
migraphx::shape is{migraphx::shape::int64_type, {1}}; migraphx::shape is{migraphx::shape::int64_type, {1}};
auto xdata = mm->add_parameter("X", ds); auto xdata = mm->add_parameter("X", ds);
...@@ -3034,9 +3027,8 @@ TEST_CASE(globalavgpool_dyn_test) ...@@ -3034,9 +3027,8 @@ TEST_CASE(globalavgpool_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = auto s = migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6}, {2, 6, {2}}}};
migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 6, 0}, {2, 6, 2}}}; auto x = mm->add_parameter("X", s);
auto x = mm->add_parameter("X", s);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("pooling", migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average}, {"dyn_global", true}}), {{"mode", migraphx::op::pooling_mode::average}, {"dyn_global", true}}),
...@@ -3081,7 +3073,7 @@ TEST_CASE(globallppool_dyn_test) ...@@ -3081,7 +3073,7 @@ TEST_CASE(globallppool_dyn_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = auto s =
migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 6, 2}, {2, 6, 2}}}; migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6, {2}}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s); auto x = mm->add_parameter("X", s);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("pooling", migraphx::make_op("pooling",
...@@ -3126,7 +3118,7 @@ TEST_CASE(globalmaxpool_dyn_test) ...@@ -3126,7 +3118,7 @@ TEST_CASE(globalmaxpool_dyn_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = auto s =
migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 6, 2}, {2, 6, 2}}}; migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6, {2}}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s); auto x = mm->add_parameter("X", s);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("pooling", migraphx::make_op("pooling",
...@@ -3198,7 +3190,7 @@ TEST_CASE(greater_dyn_test) ...@@ -3198,7 +3190,7 @@ TEST_CASE(greater_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, 9}}; std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, {9}}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto left = mm->add_parameter("l", s); auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s); auto right = mm->add_parameter("r", s);
...@@ -3242,7 +3234,7 @@ TEST_CASE(identity_dyn_test) ...@@ -3242,7 +3234,7 @@ TEST_CASE(identity_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 4, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {2, 4}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("identity"), input); mm->add_instruction(migraphx::make_op("identity"), input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -3488,7 +3480,7 @@ TEST_CASE(isnan_dyn_test) ...@@ -3488,7 +3480,7 @@ TEST_CASE(isnan_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {3, 8, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {3, 8}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
auto nan_val = std::numeric_limits<float>::quiet_NaN(); auto nan_val = std::numeric_limits<float>::quiet_NaN();
mm->add_instruction(migraphx::make_op("isnan"), input); mm->add_instruction(migraphx::make_op("isnan"), input);
...@@ -3807,7 +3799,7 @@ TEST_CASE(less_dyn_test) ...@@ -3807,7 +3799,7 @@ TEST_CASE(less_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, 9}}; std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, {9}}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto left = mm->add_parameter("l", s); auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s); auto right = mm->add_parameter("r", s);
...@@ -3859,7 +3851,7 @@ TEST_CASE(log_dyn_test) ...@@ -3859,7 +3851,7 @@ TEST_CASE(log_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("log"), input); mm->add_instruction(migraphx::make_op("log"), input);
...@@ -3904,7 +3896,7 @@ TEST_CASE(logical_and_dyn_test) ...@@ -3904,7 +3896,7 @@ TEST_CASE(logical_and_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, {4}}};
migraphx::shape s{migraphx::shape::bool_type, dd}; migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s); auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s); auto right = mm->add_parameter("r", s);
...@@ -3955,7 +3947,7 @@ TEST_CASE(logical_or_dyn_test) ...@@ -3955,7 +3947,7 @@ TEST_CASE(logical_or_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, {4}}};
migraphx::shape s{migraphx::shape::bool_type, dd}; migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s); auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s); auto right = mm->add_parameter("r", s);
...@@ -4006,7 +3998,7 @@ TEST_CASE(logical_xor_dyn_test) ...@@ -4006,7 +3998,7 @@ TEST_CASE(logical_xor_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, {4}}};
migraphx::shape s{migraphx::shape::bool_type, dd}; migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s); auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s); auto right = mm->add_parameter("r", s);
...@@ -4227,7 +4219,7 @@ TEST_CASE(lppool_dyn_test) ...@@ -4227,7 +4219,7 @@ TEST_CASE(lppool_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}}}; auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s); auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::lpnorm}, {{"mode", migraphx::op::pooling_mode::lpnorm},
...@@ -4294,7 +4286,7 @@ TEST_CASE(max_dyn_test) ...@@ -4294,7 +4286,7 @@ TEST_CASE(max_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -4497,7 +4489,7 @@ TEST_CASE(maxpool_dyn_test) ...@@ -4497,7 +4489,7 @@ TEST_CASE(maxpool_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}}}; auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s); auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {{"mode", migraphx::op::pooling_mode::max},
...@@ -4540,7 +4532,7 @@ TEST_CASE(min_dyn_test) ...@@ -4540,7 +4532,7 @@ TEST_CASE(min_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -4586,7 +4578,7 @@ TEST_CASE(fmod_dyn_test) ...@@ -4586,7 +4578,7 @@ TEST_CASE(fmod_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -4651,7 +4643,7 @@ TEST_CASE(mod_dyn_test) ...@@ -4651,7 +4643,7 @@ TEST_CASE(mod_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -4720,7 +4712,7 @@ TEST_CASE(mul_dyn_test) ...@@ -4720,7 +4712,7 @@ TEST_CASE(mul_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -4790,7 +4782,7 @@ TEST_CASE(multibroadcast_2in_dyn_test) ...@@ -4790,7 +4782,7 @@ TEST_CASE(multibroadcast_2in_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 4, 0}, {2, 2, 0}}}; migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 4}, {2, 2}}};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}}; migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3}; std::vector<int32_t> b_data{-2, -3};
auto l1 = mm->add_parameter("a", a_shape); auto l1 = mm->add_parameter("a", a_shape);
...@@ -4882,7 +4874,7 @@ TEST_CASE(neg_dyn_test) ...@@ -4882,7 +4874,7 @@ TEST_CASE(neg_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {3, 3}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
auto ret = mm->add_instruction(migraphx::make_op("neg"), input); auto ret = mm->add_instruction(migraphx::make_op("neg"), input);
mm->add_return({ret}); mm->add_return({ret});
...@@ -4939,9 +4931,9 @@ TEST_CASE(nms_dyn_batch_test) ...@@ -4939,9 +4931,9 @@ TEST_CASE(nms_dyn_batch_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 3, 0}, {6, 6, 0}, {4, 4, 0}}}; migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 3}, {6, 6}, {4, 4}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 3, 0}, {1, 1, 0}, {6, 6, 0}}}; migraphx::shape scores_s{migraphx::shape::float_type, {{1, 3}, {1, 1}, {6, 6}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s); auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s); auto scores_p = mm->add_parameter("scores", scores_s);
...@@ -4985,9 +4977,9 @@ TEST_CASE(nms_dyn_boxes_test) ...@@ -4985,9 +4977,9 @@ TEST_CASE(nms_dyn_boxes_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1, 0}, {4, 20, 0}, {4, 4, 0}}}; migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1}, {4, 20}, {4, 4}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {4, 20, 0}}}; migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1}, {1, 1}, {4, 20}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s); auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s); auto scores_p = mm->add_parameter("scores", scores_s);
...@@ -5028,9 +5020,9 @@ TEST_CASE(nms_dyn_classes_test) ...@@ -5028,9 +5020,9 @@ TEST_CASE(nms_dyn_classes_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1, 0}, {6, 6, 0}, {4, 4, 0}}}; migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1}, {6, 6}, {4, 4}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {6, 6, 0}}}; migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1}, {1, 3}, {6, 6}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s); auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s); auto scores_p = mm->add_parameter("scores", scores_s);
...@@ -5274,7 +5266,7 @@ TEST_CASE(not_dyn_test) ...@@ -5274,7 +5266,7 @@ TEST_CASE(not_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("not"), input); mm->add_instruction(migraphx::make_op("not"), input);
...@@ -5363,7 +5355,7 @@ TEST_CASE(pad_dyn_test) ...@@ -5363,7 +5355,7 @@ TEST_CASE(pad_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 2}, {2, 4, 2}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2}}, {2, 4, {2}}}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), x); mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), x);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -5827,7 +5819,7 @@ TEST_CASE(prelu_dyn_test) ...@@ -5827,7 +5819,7 @@ TEST_CASE(prelu_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto slope = mm->add_parameter("slope", s); auto slope = mm->add_parameter("slope", s);
...@@ -6028,7 +6020,7 @@ TEST_CASE(recip_dyn_test) ...@@ -6028,7 +6020,7 @@ TEST_CASE(recip_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("recip"), input); mm->add_instruction(migraphx::make_op("recip"), input);
...@@ -6065,7 +6057,7 @@ TEST_CASE(reduce_max_dynamic_axis0) ...@@ -6065,7 +6057,7 @@ TEST_CASE(reduce_max_dynamic_axis0)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 2}, {3, 5, 3}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2}}, {3, 5, {3}}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
auto reduce_max_op = migraphx::make_op("reduce_max", {{"axes", {0}}}); auto reduce_max_op = migraphx::make_op("reduce_max", {{"axes", {0}}});
mm->add_instruction(reduce_max_op, input); mm->add_instruction(reduce_max_op, input);
...@@ -6357,7 +6349,7 @@ TEST_CASE(relu_dyn_test) ...@@ -6357,7 +6349,7 @@ TEST_CASE(relu_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("relu"), input); mm->add_instruction(migraphx::make_op("relu"), input);
...@@ -6429,7 +6421,7 @@ TEST_CASE(reshape_dyn_test) ...@@ -6429,7 +6421,7 @@ TEST_CASE(reshape_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 8, 3, 1}; std::vector<int64_t> new_shape = {0, 8, 3, 1};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), input); mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
...@@ -6691,7 +6683,7 @@ TEST_CASE(round_dyn_test) ...@@ -6691,7 +6683,7 @@ TEST_CASE(round_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{4, 10, 0}; migraphx::shape::dynamic_dimension dd{4, 10};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("round"), input); mm->add_instruction(migraphx::make_op("round"), input);
...@@ -6727,7 +6719,7 @@ TEST_CASE(rsqrt_dyn_test) ...@@ -6727,7 +6719,7 @@ TEST_CASE(rsqrt_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("rsqrt"), input); mm->add_instruction(migraphx::make_op("rsqrt"), input);
...@@ -7466,10 +7458,10 @@ TEST_CASE(scatternd_reduction_dyn_test) ...@@ -7466,10 +7458,10 @@ TEST_CASE(scatternd_reduction_dyn_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type; auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type; auto itype = migraphx::shape::int64_type;
migraphx::shape::dynamic_dimension dd{3, 6, 0}; migraphx::shape::dynamic_dimension dd{3, 6};
migraphx::shape ds{migraphx::shape::float_type, {dd, dd, dd}}; migraphx::shape ds{migraphx::shape::float_type, {dd, dd, dd}};
migraphx::shape is{itype, {2, 1}}; migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {{2, 2, 0}, dd, dd}}; migraphx::shape us{dtype, {{2, 2}, dd, dd}};
auto xdata = mm->add_parameter("X", ds); auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is); auto xindex = mm->add_parameter("I", is);
...@@ -7523,7 +7515,7 @@ TEST_CASE(sigmoid_dyn_test) ...@@ -7523,7 +7515,7 @@ TEST_CASE(sigmoid_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 2, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {2, 2}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sigmoid"), input); mm->add_instruction(migraphx::make_op("sigmoid"), input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -7559,7 +7551,7 @@ TEST_CASE(sign_dyn_test) ...@@ -7559,7 +7551,7 @@ TEST_CASE(sign_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sign"), input); mm->add_instruction(migraphx::make_op("sign"), input);
...@@ -7598,7 +7590,7 @@ TEST_CASE(sin_dyn_test) ...@@ -7598,7 +7590,7 @@ TEST_CASE(sin_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sin"), input); mm->add_instruction(migraphx::make_op("sin"), input);
...@@ -7639,7 +7631,7 @@ TEST_CASE(sinh_dynamic_test) ...@@ -7639,7 +7631,7 @@ TEST_CASE(sinh_dynamic_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 4, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {2, 4}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0}; std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0};
mm->add_instruction(migraphx::make_op("sinh"), input); mm->add_instruction(migraphx::make_op("sinh"), input);
...@@ -7705,15 +7697,15 @@ TEST_CASE(slice_test) ...@@ -7705,15 +7697,15 @@ TEST_CASE(slice_test)
TEST_CASE(slice_dyn_test0) TEST_CASE(slice_dyn_test0)
{ {
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is too // Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is
// large // too large
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 3, 0}, {2, 2, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::int32_type, {{2, 3}, {2, 2}, {3, 3}}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1, 2}}, {"starts", {0, 1}}, {"ends", {1, 6}}}), x); migraphx::make_op("slice", {{"axes", {1, 2}}, {"starts", {0, 1}}, {"ends", {1, 6}}}), x);
migraphx::shape s2{migraphx::shape::int32_type, {{2, 3, 0}, {1, 1, 0}, {2, 2, 0}}}; migraphx::shape s2{migraphx::shape::int32_type, {{2, 3}, {1, 1}, {2, 2}}};
EXPECT(p.get_output_shapes().back() == s2); EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -7740,14 +7732,14 @@ TEST_CASE(slice_dyn_test1) ...@@ -7740,14 +7732,14 @@ TEST_CASE(slice_dyn_test1)
// Slice all three dynamic dimensions // Slice all three dynamic dimensions
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::int32_type, {{2, 2}, {2, 2}, {3, 3}}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("slice", migraphx::make_op("slice",
{{"axes", {0, 1, 2}}, {"starts", {0, 0, 0}}, {"ends", {2, 2, 2}}}), {{"axes", {0, 1, 2}}, {"starts", {0, 0, 0}}, {"ends", {2, 2, 2}}}),
x); x);
migraphx::shape s2{migraphx::shape::int32_type, {{2, 2, 0}, {2, 2, 0}, {2, 2, 0}}}; migraphx::shape s2{migraphx::shape::int32_type, {{2, 2}, {2, 2}, {2, 2}}};
EXPECT(p.get_output_shapes().back() == s2); EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}; migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
...@@ -7847,7 +7839,7 @@ TEST_CASE(softmax_dyn_test) ...@@ -7847,7 +7839,7 @@ TEST_CASE(softmax_dyn_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 10, 0}, {1, 3, 3}, {4, 4, 0}, {2, 2, 2}}}; {{1, 10}, {1, 3, {3}}, {4, 4}, {2, 2, {2}}}};
auto al = mm->add_parameter("a", a_shape); auto al = mm->add_parameter("a", a_shape);
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al); mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -7925,7 +7917,7 @@ TEST_CASE(sqdiff_dyn_test) ...@@ -7925,7 +7917,7 @@ TEST_CASE(sqdiff_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -7967,7 +7959,7 @@ TEST_CASE(sqrt_dynamic_test) ...@@ -7967,7 +7959,7 @@ TEST_CASE(sqrt_dynamic_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}; std::vector<float> input_data{1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184};
...@@ -8031,8 +8023,7 @@ TEST_CASE(squeeze_dyn_test) ...@@ -8031,8 +8023,7 @@ TEST_CASE(squeeze_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {1, 1}, {3, 3}}};
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
auto p0 = mm->add_parameter("x", s1); auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), p0); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), p0);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -8103,7 +8094,7 @@ TEST_CASE(sub_dyn_test) ...@@ -8103,7 +8094,7 @@ TEST_CASE(sub_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
...@@ -8145,7 +8136,7 @@ TEST_CASE(tan_dynamic_test) ...@@ -8145,7 +8136,7 @@ TEST_CASE(tan_dynamic_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1, 0, 1}; std::vector<float> input_data{-1, 0, 1};
...@@ -8186,7 +8177,7 @@ TEST_CASE(tanh_dynamic_test) ...@@ -8186,7 +8177,7 @@ TEST_CASE(tanh_dynamic_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0}; std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0};
...@@ -8294,7 +8285,7 @@ TEST_CASE(transpose_dyn_test) ...@@ -8294,7 +8285,7 @@ TEST_CASE(transpose_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}, {3, 3}}};
auto l = mm->add_parameter("X", s); auto l = mm->add_parameter("X", s);
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l);
...@@ -8349,7 +8340,7 @@ TEST_CASE(unsqueeze_dyn_test) ...@@ -8349,7 +8340,7 @@ TEST_CASE(unsqueeze_dyn_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}}; migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
auto p0 = mm->add_parameter("x", s1); auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), p0); mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), p0);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -8394,8 +8385,8 @@ TEST_CASE(where_dyn_test) ...@@ -8394,8 +8385,8 @@ TEST_CASE(where_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {{2, 3, 0}, {2, 3, 0}}}; migraphx::shape sb{migraphx::shape::bool_type, {{2, 3}, {2, 3}}};
migraphx::shape sx{migraphx::shape::float_type, {{2, 3, 0}, {2, 3, 0}}}; migraphx::shape sx{migraphx::shape::float_type, {{2, 3}, {2, 3}}};
auto lb = mm->add_parameter("predicate", sb); auto lb = mm->add_parameter("predicate", sb);
auto lx = mm->add_parameter("X", sx); auto lx = mm->add_parameter("X", sx);
......
...@@ -33,12 +33,20 @@ ...@@ -33,12 +33,20 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/verify.hpp>
bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; } bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; } bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::rewrite_quantization{}}); }
migraphx::argument eval(const migraphx::program& p)
{
auto r = p.eval({});
EXPECT(r.size() == 1);
return r.front();
}
TEST_CASE(quantizelinear) TEST_CASE(quantizelinear)
{ {
...@@ -58,8 +66,8 @@ TEST_CASE(quantizelinear) ...@@ -58,8 +66,8 @@ TEST_CASE(quantizelinear)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::rewrite_quantization opt; run_pass(*p2.get_main_module());
opt.apply(*p2.get_main_module()); EXPECT(eval(p1) == eval(p2));
EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear)); EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear)); EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
} }
...@@ -71,9 +79,9 @@ TEST_CASE(dequantizelinear) ...@@ -71,9 +79,9 @@ TEST_CASE(dequantizelinear)
std::vector<float> xv = {0, 1, 2, 5, 10, 50, 100, 150, 250}; std::vector<float> xv = {0, 1, 2, 5, 10, 50, 100, 150, 250};
migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}}; migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}};
std::vector<float> sv = {2, 2, 2, 2, 2, 2, 2, 2, 2}; std::vector<float> sv = {2, 2, 2, 2, 2, 2, 2, 2, 2};
migraphx::shape zs{migraphx::shape::uint8_type, {1, 3, 3}}; migraphx::shape zs{migraphx::shape::float_type, {1, 3, 3}};
std::vector<uint8_t> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<float> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0};
auto create_program = [&]() { auto create_program = [&]() {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_literal(xs, xv); auto x = mm->add_literal(xs, xv);
...@@ -86,8 +94,8 @@ TEST_CASE(dequantizelinear) ...@@ -86,8 +94,8 @@ TEST_CASE(dequantizelinear)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::rewrite_quantization opt; run_pass(*p2.get_main_module());
opt.apply(*p2.get_main_module()); EXPECT(eval(p1) == eval(p2));
EXPECT(any_of(*p1.get_main_module(), &is_dequantizelinear)); EXPECT(any_of(*p1.get_main_module(), &is_dequantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear)); EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear));
} }
......
...@@ -41,22 +41,13 @@ TEST_CASE(test_shape_default) ...@@ -41,22 +41,13 @@ TEST_CASE(test_shape_default)
TEST_CASE(test_dyn_4arg_constructor) TEST_CASE(test_dyn_4arg_constructor)
{ {
migraphx::shape s{migraphx::shape::float_type, migraphx::shape s0{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {{}, {}, {}}};
{ migraphx::shape s1{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {}};
1, std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {{1, 4}, {4, 4}, {4, 4}};
4, EXPECT(s0.dynamic());
4, EXPECT(s0.dyn_dims() == expected_dyn_dims);
}, EXPECT(s1.dynamic());
{ EXPECT(s1.dyn_dims() == expected_dyn_dims);
4,
4,
4,
},
{0, 0, 0}};
std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {
{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
EXPECT(s.dynamic());
EXPECT(s.dyn_dims() == expected_dyn_dims);
} }
TEST_CASE(test_shape_assign) TEST_CASE(test_shape_assign)
...@@ -85,17 +76,26 @@ TEST_CASE(test_shape_standard) ...@@ -85,17 +76,26 @@ TEST_CASE(test_shape_standard)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
TEST_CASE(test_shape_standard_singleton_dim)
{
migraphx::shape s{migraphx::shape::float_type, {5, 1, 8}, {8, 4, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_min_max_opt) TEST_CASE(test_shape_min_max_opt)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
EXPECT(s.min_lens() == s.lens()); EXPECT(s.min_lens() == s.lens());
EXPECT(s.max_lens() == s.lens()); EXPECT(s.max_lens() == s.lens());
EXPECT(s.opt_lens() == s.lens()); EXPECT(s.opt_lens().empty());
} }
TEST_CASE(test_shape_dynamic_fixed) TEST_CASE(test_shape_dynamic_fixed)
{ {
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {2, 2}, {3, 3}}};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -106,7 +106,8 @@ TEST_CASE(test_shape_dynamic_fixed) ...@@ -106,7 +106,8 @@ TEST_CASE(test_shape_dynamic_fixed)
EXPECT(not s.dyn_dims().at(0).has_optimal()); EXPECT(not s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2, 3}); EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.max_lens() == std::vector<std::size_t>{2, 2, 3}); EXPECT(s.max_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.opt_lens() == std::vector<std::size_t>{0, 0, 0}); std::vector<std::set<std::size_t>> e_opt_lens = {{}, {}, {}};
EXPECT(s.opt_lens() == e_opt_lens);
EXPECT(s.bytes() == 2 * 2 * 3 * sizeof(float)); EXPECT(s.bytes() == 2 * 2 * 3 * sizeof(float));
} }
...@@ -114,8 +115,8 @@ TEST_CASE(test_shape_dynamic_not_fixed) ...@@ -114,8 +115,8 @@ TEST_CASE(test_shape_dynamic_not_fixed)
{ {
using migraphx::shape; using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {}; std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2}); dims.push_back(shape::dynamic_dimension{2, 5, {2}});
dims.push_back(shape::dynamic_dimension{2, 8, 0}); dims.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s{migraphx::shape::float_type, dims}; migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
...@@ -127,18 +128,16 @@ TEST_CASE(test_shape_dynamic_not_fixed) ...@@ -127,18 +128,16 @@ TEST_CASE(test_shape_dynamic_not_fixed)
EXPECT(s.dyn_dims().at(0).has_optimal()); EXPECT(s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2}); EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2});
EXPECT(s.max_lens() == std::vector<std::size_t>{5, 8}); EXPECT(s.max_lens() == std::vector<std::size_t>{5, 8});
EXPECT(s.opt_lens() == std::vector<std::size_t>{2, 0}); EXPECT(s.opt_lens() == std::vector<std::set<std::size_t>>{{2}, {}});
EXPECT(s.bytes() == 5 * 8 * sizeof(float)); EXPECT(s.bytes() == 5 * 8 * sizeof(float));
} }
TEST_CASE(test_shape_dynamic_compares) TEST_CASE(test_shape_dynamic_compares)
{ {
using migraphx::shape; using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2}; auto a = shape::dynamic_dimension{2, 5, {2}};
auto b = a; auto c = shape::dynamic_dimension{2, 5, {2}};
auto c = shape::dynamic_dimension{2, 5, 2}; auto d = shape::dynamic_dimension{3, 8};
auto d = shape::dynamic_dimension{3, 8, 4};
EXPECT(a == b);
EXPECT(a == c); EXPECT(a == c);
EXPECT(a != d); EXPECT(a != d);
...@@ -163,13 +162,13 @@ TEST_CASE(test_shape_dynamic_compares) ...@@ -163,13 +162,13 @@ TEST_CASE(test_shape_dynamic_compares)
TEST_CASE(dynamic_dimension_size_t_compares) TEST_CASE(dynamic_dimension_size_t_compares)
{ {
using migraphx::shape; using migraphx::shape;
auto a = shape::dynamic_dimension{2, 2, 2}; auto a = shape::dynamic_dimension{2, 2, {2}};
EXPECT(a == 2); EXPECT(a == 2);
EXPECT(a != 3); EXPECT(a != 3);
EXPECT(static_cast<std::size_t>(2) == a); EXPECT(static_cast<std::size_t>(2) == a);
EXPECT(static_cast<std::size_t>(3) != a); EXPECT(static_cast<std::size_t>(3) != a);
auto b = shape::dynamic_dimension{2, 4, 0}; auto b = shape::dynamic_dimension{2, 4};
EXPECT(b != 2); EXPECT(b != 2);
EXPECT(static_cast<std::size_t>(2) != b); EXPECT(static_cast<std::size_t>(2) != b);
} }
...@@ -177,25 +176,25 @@ TEST_CASE(dynamic_dimension_size_t_compares) ...@@ -177,25 +176,25 @@ TEST_CASE(dynamic_dimension_size_t_compares)
TEST_CASE(dynamic_dimension_add_sub_fixed) TEST_CASE(dynamic_dimension_add_sub_fixed)
{ {
using migraphx::shape; using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2}; auto a = shape::dynamic_dimension{2, 5, {2}};
a += 3; a += 3;
EXPECT(a == shape::dynamic_dimension{5, 8, 5}); EXPECT(a == shape::dynamic_dimension{5, 8, {5}});
a -= 3; a -= 3;
EXPECT(a == shape::dynamic_dimension{2, 5, 2}); EXPECT(a == shape::dynamic_dimension{2, 5, {2}});
auto b = shape::dynamic_dimension{3, 6, 3}; auto b = shape::dynamic_dimension{3, 6, {3}};
EXPECT((a + 1) == b); EXPECT((a + 1) == b);
EXPECT((1 + a) == b); EXPECT((1 + a) == b);
EXPECT((b - 1) == a); EXPECT((b - 1) == a);
auto c = shape::dynamic_dimension{4, 7, 4}; auto c = shape::dynamic_dimension{4, 7, {4}};
EXPECT((a + 2) == c); EXPECT((a + 2) == c);
EXPECT((2 + a) == c); EXPECT((2 + a) == c);
EXPECT((c - 2) == a); EXPECT((c - 2) == a);
auto d = shape::dynamic_dimension{4, 8, 0}; auto d = shape::dynamic_dimension{4, 8};
auto e = shape::dynamic_dimension{2, 6, 0}; auto e = shape::dynamic_dimension{2, 6};
EXPECT((d - 2) == e); EXPECT((d - 2) == e);
EXPECT((e + 2) == d); EXPECT((e + 2) == d);
EXPECT((2 + e) == d); EXPECT((2 + e) == d);
...@@ -205,8 +204,8 @@ TEST_CASE(test_shape_dynamic_errors) ...@@ -205,8 +204,8 @@ TEST_CASE(test_shape_dynamic_errors)
{ {
using migraphx::shape; using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {}; std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2}); dims.push_back(shape::dynamic_dimension{2, 5, {2}});
dims.push_back(shape::dynamic_dimension{2, 8, 0}); dims.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s{shape::float_type, dims}; migraphx::shape s{shape::float_type, dims};
EXPECT(test::throws([&] { s.elements(); })); EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.index({0, 1}); })); EXPECT(test::throws([&] { s.index({0, 1}); }));
...@@ -220,13 +219,13 @@ TEST_CASE(test_shape_dynamic_serialize) ...@@ -220,13 +219,13 @@ TEST_CASE(test_shape_dynamic_serialize)
{ {
using migraphx::shape; using migraphx::shape;
std::vector<shape::dynamic_dimension> dims1 = {}; std::vector<shape::dynamic_dimension> dims1 = {};
dims1.push_back(shape::dynamic_dimension{2, 5, 2}); dims1.push_back(shape::dynamic_dimension{2, 5, {2}});
dims1.push_back(shape::dynamic_dimension{2, 8, 0}); dims1.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s1{shape::float_type, dims1}; migraphx::shape s1{shape::float_type, dims1};
auto v1 = migraphx::to_value(s1); auto v1 = migraphx::to_value(s1);
std::vector<shape::dynamic_dimension> dims2 = {}; std::vector<shape::dynamic_dimension> dims2 = {};
dims2.push_back(shape::dynamic_dimension{2, 5, 2}); dims2.push_back(shape::dynamic_dimension{2, 5, {2}});
migraphx::shape s2{shape::uint64_type, dims2}; migraphx::shape s2{shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2); auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2); EXPECT(v1 != v2);
...@@ -285,14 +284,13 @@ TEST_CASE(test_shape_ndim_static) ...@@ -285,14 +284,13 @@ TEST_CASE(test_shape_ndim_static)
TEST_CASE(test_shape_ndim_dyn) TEST_CASE(test_shape_ndim_dyn)
{ {
migraphx::shape s0{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}}; migraphx::shape s0{migraphx::shape::float_type, {{2, 2}, {2, 2}}};
EXPECT(s0.ndim() == 2); EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}}; migraphx::shape s1{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
EXPECT(s1.ndim() == 4); EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type, migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {1, 1}, {3, 3}}};
{{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}};
EXPECT(s2.ndim() == 5); EXPECT(s2.ndim() == 5);
} }
...@@ -327,17 +325,60 @@ TEST_CASE(test_shape_static_to_dynamic) ...@@ -327,17 +325,60 @@ TEST_CASE(test_shape_static_to_dynamic)
{ {
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_dynamic(); migraphx::shape s1 = s0.to_dynamic();
migraphx::shape s2{migraphx::shape::float_type, {{1, 1, 0}, {2, 2, 0}, {4, 4, 0}, {4, 4, 0}}}; migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}};
EXPECT(s1 == s2); EXPECT(s1 == s2);
} }
TEST_CASE(test_shape_dyn_to_dynamic) TEST_CASE(test_shape_dyn_to_dynamic)
{ {
migraphx::shape s0{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}}; migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
migraphx::shape s1 = s0.to_dynamic(); migraphx::shape s1 = s0.to_dynamic();
EXPECT(s0 == s1); EXPECT(s0 == s1);
} }
TEST_CASE(test_shape_subshapes_to_dynamic)
{
std::vector<migraphx::shape> sub_shapes0 = {};
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s0{sub_shapes0};
migraphx::shape s1 = s0.to_dynamic();
std::vector<migraphx::shape> sub_shapes1 = {};
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}});
migraphx::shape s2{sub_shapes1};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_dyn_to_static)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 2}, {2, 10}, {2, 10}}};
migraphx::shape s1 = s0.to_static(4);
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 4, 4}};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_static_to_static)
{
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_static(8);
EXPECT(s0 == s1);
}
TEST_CASE(test_shape_subshapes_to_static)
{
std::vector<migraphx::shape> sub_shapes0 = {};
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s0{sub_shapes0};
migraphx::shape s1 = s0.to_static(3);
std::vector<migraphx::shape> sub_shapes1 = {};
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4}});
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s2{sub_shapes1};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_overlap) TEST_CASE(test_shape_overlap)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}};
......
...@@ -509,6 +509,34 @@ TEST_CASE(simplify_dot_add) ...@@ -509,6 +509,34 @@ TEST_CASE(simplify_dot_add)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_conv_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto c = m1.add_literal(migraphx::generate_literal(s, 1));
auto w = m1.add_literal(migraphx::generate_literal(ws, 2));
auto sum = m1.add_instruction(migraphx::make_op("add"), c, x);
auto conv = m1.add_instruction(migraphx::make_op("convolution"), sum, w);
m1.add_instruction(pass_op{}, conv);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto c = m2.add_literal(migraphx::generate_literal(s, 1));
auto w = m2.add_literal(migraphx::generate_literal(ws, 2));
auto conv1 = m2.add_instruction(migraphx::make_op("convolution"), c, w);
auto conv2 = m2.add_instruction(migraphx::make_op("convolution"), x, w);
auto sum = m2.add_instruction(migraphx::make_op("add"), conv1, conv2);
m2.add_instruction(pass_op{}, sum);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast1) TEST_CASE(simplify_inner_broadcast1)
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
......
...@@ -402,9 +402,10 @@ TEST_CASE(conv_bias_add) ...@@ -402,9 +402,10 @@ TEST_CASE(conv_bias_add)
auto bias = m1.add_parameter("bias", s6); auto bias = m1.add_parameter("bias", s6);
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0}); auto zero = m1.add_literal(std::int8_t{0});
auto zero32 = m1.add_literal(std::int32_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution", auto c1 = m1.add_instruction(migraphx::make_op("convolution",
...@@ -428,9 +429,10 @@ TEST_CASE(conv_bias_add) ...@@ -428,9 +429,10 @@ TEST_CASE(conv_bias_add)
auto bias = m2.add_parameter("bias", s6); auto bias = m2.add_parameter("bias", s6);
auto scale = m2.add_literal(0.5f); auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f); auto scale1 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero); auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}}, {{"padding", {0, 0, 0, 0}},
...@@ -468,9 +470,10 @@ TEST_CASE(conv_pooling_dot) ...@@ -468,9 +470,10 @@ TEST_CASE(conv_pooling_dot)
auto input = m1.add_parameter("input", s7); auto input = m1.add_parameter("input", s7);
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0}); auto zero = m1.add_literal(std::int8_t{0});
auto zero32 = m1.add_literal(std::int32_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m1, "dequantizelinear", ab, scale, zero); auto d3 = add_quantize_op(m1, "dequantizelinear", ab, scale, zero);
auto d4 = add_quantize_op(m1, "dequantizelinear", db, scale, zero); auto d4 = add_quantize_op(m1, "dequantizelinear", db, scale, zero);
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
...@@ -515,10 +518,11 @@ TEST_CASE(conv_pooling_dot) ...@@ -515,10 +518,11 @@ TEST_CASE(conv_pooling_dot)
auto input = m2.add_parameter("input", s7); auto input = m2.add_parameter("input", s7);
auto scale = m2.add_literal(0.5f); auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f); auto scale1 = m2.add_literal(0.25f);
auto scale2 = m2.add_literal(0.25f); auto scale2 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero); auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero); auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
...@@ -572,9 +576,10 @@ TEST_CASE(mobilenet_snippet) ...@@ -572,9 +576,10 @@ TEST_CASE(mobilenet_snippet)
auto input = mm.add_parameter("input", s7); auto input = mm.add_parameter("input", s7);
auto scale = mm.add_literal(0.5f); auto scale = mm.add_literal(0.5f);
auto zero = mm.add_literal(std::int8_t{0}); auto zero = mm.add_literal(std::int8_t{0});
auto zero32 = mm.add_literal(std::int32_t{0});
auto d1 = add_quantize_op(mm, "dequantizelinear", weights, scale, zero); auto d1 = add_quantize_op(mm, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero); auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(mm, "dequantizelinear", ab, scale, zero); auto d3 = add_quantize_op(mm, "dequantizelinear", ab, scale, zero);
auto d4 = add_quantize_op(mm, "dequantizelinear", db, scale, zero); auto d4 = add_quantize_op(mm, "dequantizelinear", db, scale, zero);
auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero);
......
...@@ -1322,6 +1322,46 @@ TEST_CASE(transpose_slice) ...@@ -1322,6 +1322,46 @@ TEST_CASE(transpose_slice)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(transpose_slice_unsqueeze)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {4, 1024, 96, 64}});
auto transpose1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}),
transpose1);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {16}}, {"ends", {24}}}),
transpose1);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {40}}}),
transpose1);
m1.add_return({slice1, slice2, slice3});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {4, 1024, 96, 64}});
auto unsq =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 4, 1}}}), unsq);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transpose);
auto sq2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {5}}}), transpose);
auto sq3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
m2.add_return({sq1, sq2, sq3});
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_diff_perm) TEST_CASE(transpose_slice_diff_perm)
{ {
migraphx::module m1; migraphx::module m1;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::split_single_dyn_dim{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(dynamic_batch)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sm_shape.lens()}}), literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
TEST_CASE(multiple_outputs)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sm_shape.lens()}}), literal_ins);
auto add0_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
auto add1_ins = submod->add_instruction(migraphx::make_op("add"), sm_input, sm_input);
submod->add_return({add0_ins, add1_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
migraphx::shape tmp_s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
sub_shapes.push_back(tmp_s);
sub_shapes.push_back(tmp_s);
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret0 =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
auto ret1 =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), sm_ins);
mm0->add_return({ret0, ret1});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add0_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
auto add1_ins = mm1->add_instruction(migraphx::make_op("add"), input1, input1);
mm1->add_return({add0_ins, add1_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
TEST_CASE(broadcast_match)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", sm_shape.lens()}}),
literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}}), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -41,10 +41,12 @@ TEST_CASE(make_invalid_target) ...@@ -41,10 +41,12 @@ TEST_CASE(make_invalid_target)
TEST_CASE(targets) TEST_CASE(targets)
{ {
// GCC doesn't load libmigraphx_ref unless necesssary even though it is linked to the test.
// Force it to load by making ref target
#if defined(__GNUC__) && !defined(__clang__)
auto ref_target = migraphx::make_target("ref");
#endif
auto ts = migraphx::get_targets(); auto ts = migraphx::get_targets();
EXPECT(ts.size() == 0);
auto ref_t = migraphx::make_target("ref");
ts = migraphx::get_targets();
EXPECT(ts.size() == 1); EXPECT(ts.size() == 1);
} }
......
...@@ -73,7 +73,8 @@ int main(int argc, const char* argv[]) ...@@ -73,7 +73,8 @@ int main(int argc, const char* argv[])
"test_if_literal", "test_if_literal",
"test_select_module_add", "test_select_module_add",
"test_select_module_reduce", "test_select_module_reduce",
"test_select_module_conv"}); "test_select_module_conv",
"test_split_single_dyn_dim"});
rv.disable_test_for("gpu", {"test_conv_bn_add"}); rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv); rv.run(argc, argv);
} }
...@@ -67,15 +67,17 @@ inline void verify_load_save(const migraphx::program& p) ...@@ -67,15 +67,17 @@ inline void verify_load_save(const migraphx::program& p)
EXPECT(p == loaded); EXPECT(p == loaded);
} }
inline void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false) inline void compile_check(migraphx::program& p,
const migraphx::target& t,
migraphx::compile_options c_opts,
bool show_trace = false)
{ {
auto name = t.name(); auto name = t.name();
auto shapes = p.get_output_shapes(); auto shapes = p.get_output_shapes();
std::stringstream ss; std::stringstream ss;
migraphx::compile_options options;
if(show_trace) if(show_trace)
options.trace = migraphx::tracer{std::cout}; c_opts.trace = migraphx::tracer{std::cout};
p.compile(t, options); p.compile(t, c_opts);
if(shapes.size() != p.get_output_shapes().size()) if(shapes.size() != p.get_output_shapes().size())
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
...@@ -115,19 +117,23 @@ void run_verify::validate(const migraphx::target& t, ...@@ -115,19 +117,23 @@ void run_verify::validate(const migraphx::target& t,
} }
std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p, std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
migraphx::parameter_map inputs) const migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const
{ {
migraphx::target t = migraphx::make_target("ref"); migraphx::target t = migraphx::make_target("ref");
auto_print pp{p, t.name()}; auto_print pp{p, t.name()};
compile_check(p, t); compile_check(p, t, c_opts);
return p.eval(std::move(inputs)); return p.eval(std::move(inputs));
} }
std::pair<migraphx::program, std::vector<migraphx::argument>> run_verify::run_target( std::pair<migraphx::program, std::vector<migraphx::argument>>
const migraphx::target& t, migraphx::program p, const migraphx::parameter_map& inputs) const run_verify::run_target(const migraphx::target& t,
migraphx::program p,
const migraphx::parameter_map& inputs,
const migraphx::compile_options& c_opts) const
{ {
auto_print pp{p, t.name()}; auto_print pp{p, t.name()};
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, (trace_target == t.name())); compile_check(p, t, c_opts, (trace_target == t.name()));
migraphx::parameter_map m; migraphx::parameter_map m;
for(auto&& input : inputs) for(auto&& input : inputs)
{ {
...@@ -157,7 +163,9 @@ auto get_hash(const T& x) ...@@ -157,7 +163,9 @@ auto get_hash(const T& x)
return std::hash<T>{}(x); return std::hash<T>{}(x);
} }
void run_verify::verify(const std::string& name, const migraphx::program& p) const void run_verify::verify(const std::string& name,
const migraphx::program& p,
const migraphx::compile_options& c_opts) const
{ {
using result_future = using result_future =
std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>; std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
...@@ -197,13 +205,13 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -197,13 +205,13 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
} }
} }
auto gold_f = detach_async([=] { return run_ref(p, m); }); auto gold_f = detach_async([=] { return run_ref(p, m, c_opts); });
for(const auto& tname : target_names) for(const auto& tname : target_names)
{ {
target_info ti = get_target_info(tname); target_info ti = get_target_info(tname);
auto t = migraphx::make_target(tname); auto t = migraphx::make_target(tname);
results.emplace_back(tname, results.emplace_back(
detach_async([=] { return run_target(t, p, m); }, ti.parallel)); tname, detach_async([=] { return run_target(t, p, m, c_opts); }, ti.parallel));
} }
assert(gold_f.valid()); assert(gold_f.valid());
...@@ -244,7 +252,7 @@ void run_verify::run(int argc, const char* argv[]) const ...@@ -244,7 +252,7 @@ void run_verify::run(int argc, const char* argv[]) const
for(auto&& p : get_programs()) for(auto&& p : get_programs())
{ {
labels[p.section].push_back(p.name); labels[p.section].push_back(p.name);
test::add_test_case(p.name, [=] { verify(p.name, p.get_program()); }); test::add_test_case(p.name, [=] { verify(p.name, p.get_program(), p.compile_options); });
} }
test::driver d{}; test::driver d{};
d.get_case_names = [&](const std::string& name) -> std::vector<std::string> { d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
......
...@@ -40,15 +40,19 @@ struct target_info ...@@ -40,15 +40,19 @@ struct target_info
struct run_verify struct run_verify
{ {
std::vector<migraphx::argument> run_ref(migraphx::program p, std::vector<migraphx::argument> run_ref(migraphx::program p,
migraphx::parameter_map inputs) const; migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const;
std::pair<migraphx::program, std::vector<migraphx::argument>> std::pair<migraphx::program, std::vector<migraphx::argument>>
run_target(const migraphx::target& t, run_target(const migraphx::target& t,
migraphx::program p, migraphx::program p,
const migraphx::parameter_map& inputs) const; const migraphx::parameter_map& inputs,
const migraphx::compile_options& c_opts) const;
void validate(const migraphx::target& t, void validate(const migraphx::target& t,
const migraphx::program& p, const migraphx::program& p,
const migraphx::parameter_map& m) const; const migraphx::parameter_map& m) const;
void verify(const std::string& name, const migraphx::program& p) const; void verify(const std::string& name,
const migraphx::program& p,
const migraphx::compile_options& c_opts) const;
void run(int argc, const char* argv[]) const; void run(int argc, const char* argv[]) const;
target_info get_target_info(const std::string& name) const; target_info get_target_info(const std::string& name) const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_add_conv_constant : verify_program<test_add_conv_constant>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
auto x = mm->add_parameter("x", s);
auto c = mm->add_literal(migraphx::generate_literal(s, 1));
auto w = mm->add_literal(migraphx::generate_literal(ws, 2));
auto sum = mm->add_instruction(migraphx::make_op("add"), c, x);
mm->add_instruction(migraphx::make_op("convolution"), sum, w);
return p;
}
};
...@@ -33,13 +33,12 @@ struct test_concat_axis_2 : verify_program<test_concat_axis_2> ...@@ -33,13 +33,12 @@ struct test_concat_axis_2 : verify_program<test_concat_axis_2>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {3, 2, 1}}; migraphx::shape s{migraphx::shape::int32_type, {3, 2, 1}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2, 1}}; auto x0 = mm->add_parameter("x0", s);
migraphx::shape s2{migraphx::shape::int32_type, {3, 2, 1}}; auto x1 = mm->add_parameter("x1", s);
auto l0 = mm->add_parameter("x", s0); auto x2 = mm->add_parameter("x2", s);
auto l1 = mm->add_parameter("y", s1); auto x3 = mm->add_parameter("x3", s);
auto l2 = mm->add_parameter("z", s2); mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x0, x1, x2, x3);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), l0, l1, l2);
return p; return p;
} }
}; };
...@@ -37,10 +37,13 @@ struct test_quantizelinear_int32 : verify_program<test_quantizelinear_int32> ...@@ -37,10 +37,13 @@ struct test_quantizelinear_int32 : verify_program<test_quantizelinear_int32>
migraphx::shape sx{migraphx::shape::int32_type, {2, 2, 2}}; migraphx::shape sx{migraphx::shape::int32_type, {2, 2, 2}};
migraphx::shape ss{migraphx::shape::float_type, {2, 2, 2}}; migraphx::shape ss{migraphx::shape::float_type, {2, 2, 2}};
migraphx::shape sz{migraphx::shape::int8_type, {2, 2, 2}}; migraphx::shape sz{migraphx::shape::int8_type, {2, 2, 2}};
auto input1 = mm->add_parameter("x", sx); auto input1 = mm->add_parameter("x", sx);
auto input2 = mm->add_parameter("y_scale", ss); auto input2 = mm->add_parameter("y_scale", ss);
auto input3 = mm->add_parameter("y_zero_point", sz); auto input3 = mm->add_parameter("y_zero_point", sz);
auto r = mm->add_instruction(migraphx::make_op("quantizelinear"), input1, input2, input3); auto input1_float = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), input1);
auto r =
mm->add_instruction(migraphx::make_op("quantizelinear"), input1_float, input2, input3);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
/**
* Test that the split_single_dyn_dim GPU compiler pass produces the same results as ref.
*/
struct test_split_single_dyn_dim : verify_program<test_split_single_dyn_dim>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm->add_literal(migraphx::literal{lit_s, {6}});
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input = mm->add_parameter("data", s);
auto broadcast_lit =
mm->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), input, broadcast_lit);
mm->add_return({add_ins});
return p;
}
migraphx::compile_options get_compile_options() const
{
migraphx::compile_options co;
co.split_single_dyn_dim = true;
return co;
};
};
...@@ -24,15 +24,17 @@ ...@@ -24,15 +24,17 @@
#ifndef MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP #ifndef MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP
#define MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP #define MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP
#include <functional>
#include <migraphx/auto_register.hpp> #include <migraphx/auto_register.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <functional> #include <migraphx/compile_options.hpp>
struct program_info struct program_info
{ {
std::string name; std::string name;
std::string section; std::string section;
std::function<migraphx::program()> get_program; std::function<migraphx::program()> get_program;
migraphx::compile_options compile_options;
}; };
void register_program_info(const program_info& pi); void register_program_info(const program_info& pi);
...@@ -45,9 +47,10 @@ struct register_verify_program_action ...@@ -45,9 +47,10 @@ struct register_verify_program_action
{ {
T x; T x;
program_info pi; program_info pi;
pi.name = migraphx::get_type_name<T>(); pi.name = migraphx::get_type_name<T>();
pi.section = x.section(); pi.section = x.section();
pi.get_program = [x] { return x.create_program(); }; pi.get_program = [x] { return x.create_program(); };
pi.compile_options = x.get_compile_options();
register_program_info(pi); register_program_info(pi);
} }
}; };
...@@ -59,6 +62,7 @@ template <class T> ...@@ -59,6 +62,7 @@ template <class T>
struct verify_program : auto_register_verify_program<T> struct verify_program : auto_register_verify_program<T>
{ {
std::string section() const { return "general"; }; std::string section() const { return "general"; };
migraphx::compile_options get_compile_options() const { return migraphx::compile_options{}; };
}; };
#endif #endif
...@@ -63,8 +63,25 @@ def parse_args(): ...@@ -63,8 +63,25 @@ def parse_args():
type=str, type=str,
action='append', action='append',
help='specify input parameter dimension \ help='specify input parameter dimension \
with the following format --input_dim input_name:dim0,dim1,dim2...' with the following format --input-dim input_name:dim0,dim1,dim2...'
) )
parser.add_argument('--target',
type=str,
default='gpu',
help='target to compile and run MIGraphX on')
parser.add_argument('--ort-run',
dest="ort_run",
action='store_true',
default=False,
help='only perform an onnxruntime run')
parser.add_argument('--ort-logging',
dest="ort_logging",
action='store_true',
default=False,
help='Turn on ort VERBOSE logging via session options')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -111,7 +128,7 @@ def get_np_datatype(in_type): ...@@ -111,7 +128,7 @@ def get_np_datatype(in_type):
'uint16_type': np.uint16, 'uint16_type': np.uint16,
'int8_type': np.int8, 'int8_type': np.int8,
'uint8_type': np.uint8, 'uint8_type': np.uint8,
'bool_type': np.bool_ 'bool_type': bool
} }
return datatypes[in_type] return datatypes[in_type]
...@@ -159,7 +176,8 @@ def main(): ...@@ -159,7 +176,8 @@ def main():
if args.verbose: if args.verbose:
print(model) print(model)
model.compile(migraphx.get_target('gpu')) if not args.ort_run:
model.compile(migraphx.get_target(args.target))
params = {} params = {}
test_inputs = {} test_inputs = {}
...@@ -178,10 +196,19 @@ def main(): ...@@ -178,10 +196,19 @@ def main():
test_inputs[name] = test_input test_inputs[name] = test_input
params[name] = migraphx.argument(test_input) params[name] = migraphx.argument(test_input)
pred_migx = np.array(model.run(params)[-1]) if not args.ort_run:
pred_migx = np.array(model.run(params)[-1])
if use_onnx: if use_onnx:
sess = ort.InferenceSession(model_name, providers=[args.provider]) sess_op = ort.SessionOptions()
if args.ort_logging:
sess_op.log_verbosity_level = 0
sess_op.log_severity_level = 0
sess = ort.InferenceSession(model_name,
sess_options=sess_op,
providers=[args.provider])
ort_params = {} ort_params = {}
for input in sess.get_inputs(): for input in sess.get_inputs():
...@@ -239,14 +266,15 @@ def main(): ...@@ -239,14 +266,15 @@ def main():
y_out = sess.run(y, feed_dict=tf_dict) y_out = sess.run(y, feed_dict=tf_dict)
pred_fw = y_out pred_fw = y_out
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance, if not args.ort_run:
args.tolerance, args.verbose) is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
verbose_string = ' Rerun with --verbose for detailed information.' \ args.tolerance, args.verbose)
if not args.verbose else '' verbose_string = ' Rerun with --verbose for detailed information.' \
if is_correct: if not args.verbose else ''
print('PASSED: MIGraphX meets tolerance') if is_correct:
else: print('PASSED: MIGraphX meets tolerance')
print('FAILED: MIGraphX is not within tolerance.' + verbose_string) else:
print('FAILED: MIGraphX is not within tolerance.' + verbose_string)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment