Commit b119ed8f authored by Alan Turner's avatar Alan Turner
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into develop

parents 26d1a969 6f1c947f
 slice_var_input_static1:
)
data
starts
ends
axesoutput"Sliceslice_var_input_static1Z
data


Z
starts

Z
ends

Z
axes

b
output


B
\ No newline at end of file
 slice_var_input_steps_error:
0arg_step"Constant*
value**Bstep
3
data
starts
ends
axes
arg_stepoutput"Sliceslice_var_input_steps_errorZ
data


Z
starts

Z
ends

Z
axes

b
output


B
\ No newline at end of file
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/op/common.hpp>
#include <sstream> #include <sstream>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -156,13 +157,13 @@ TEST_CASE(broadcast) ...@@ -156,13 +157,13 @@ TEST_CASE(broadcast)
{ {
std::vector<std::size_t> lens{1, 1}; std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}}; migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input); throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input);
} }
{ {
std::vector<std::size_t> lens{2, 2}; std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}}; migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input); throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input);
} }
{ {
...@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape) ...@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape)
input); input);
} }
template <class T> void test_softmax_variations(const std::string& name)
void test_softmax_variations()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 0}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 1}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 2}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 3}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4; int axis = 4;
throws_shape(T{axis}, input); throws_shape(migraphx::make_op(name, {{"axis", axis}}), input);
} }
} }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); } TEST_CASE(logsoftmax) { test_softmax_variations("logsoftmax"); }
TEST_CASE(softmax) { test_softmax_variations("softmax"); }
TEST_CASE(lstm) TEST_CASE(lstm)
{ {
...@@ -2328,47 +2338,54 @@ TEST_CASE(dqlinear_mismatch_type) ...@@ -2328,47 +2338,54 @@ TEST_CASE(dqlinear_mismatch_type)
throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros); throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros);
} }
template <class T> void test_reduce_ops(const std::string& name)
void test_reduce_ops()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}},
migraphx::make_op(name),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}},
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input); migraphx::make_op(name, {{"axes", {0, 1, 2, 3}}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}},
migraphx::make_op(name, {{"axes", {2, 3}}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::make_op(name, {{"axes", {0}}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}},
migraphx::make_op(name, {{"axes", {-1}}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input); throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input);
} }
} }
// dynamic shape // dynamic shape
template <class T> void test_dyn_reduce_ops(const std::string& name)
void test_dyn_reduce_ops()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}}; migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})}, std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})},
T{{-1}}, migraphx::make_op(name, {{"axes", {-1}}}),
input); input);
} }
{ {
...@@ -2376,7 +2393,7 @@ void test_dyn_reduce_ops() ...@@ -2376,7 +2393,7 @@ void test_dyn_reduce_ops()
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})}, std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})},
T{{0}}, migraphx::make_op(name, {{"axes", {0}}}),
input); input);
} }
{ {
...@@ -2385,24 +2402,24 @@ void test_dyn_reduce_ops() ...@@ -2385,24 +2402,24 @@ void test_dyn_reduce_ops()
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})}, std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})},
T{{}}, migraphx::make_op(name),
input); input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}}; migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
throws_shape(T{{4}}, input); throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input);
} }
} }
TEST_CASE(reduce_max) { test_reduce_ops<migraphx::op::reduce_max>(); } TEST_CASE(reduce_max) { test_reduce_ops("reduce_max"); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); } TEST_CASE(reduce_mean) { test_reduce_ops("reduce_mean"); }
TEST_CASE(reduce_prod) { test_reduce_ops<migraphx::op::reduce_prod>(); } TEST_CASE(reduce_prod) { test_reduce_ops("reduce_prod"); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); } TEST_CASE(reduce_sum) { test_reduce_ops("reduce_sum"); }
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_max>(); } TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops("reduce_max"); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_mean>(); } TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops("reduce_mean"); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_prod>(); } TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops("reduce_prod"); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_sum>(); } TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops("reduce_sum"); }
TEST_CASE(reshape_shape) TEST_CASE(reshape_shape)
{ {
...@@ -2822,7 +2839,7 @@ TEST_CASE(select_module_dyn) ...@@ -2822,7 +2839,7 @@ TEST_CASE(select_module_dyn)
input); input);
} }
TEST_CASE(slice_shape) TEST_CASE(slice_static_shape)
{ {
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}}; migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
...@@ -2840,6 +2857,67 @@ TEST_CASE(slice_shape) ...@@ -2840,6 +2857,67 @@ TEST_CASE(slice_shape)
input); input);
} }
TEST_CASE(slice_var_inputs_static_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"axes", {1, 2}}}),
input,
starts,
ends);
}
TEST_CASE(slice_var_inputs_static_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice"),
input,
starts,
ends,
axes);
}
TEST_CASE(slice_var_inputs_static_error0)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {3}};
throws_shape(migraphx::make_op("slice"), input, starts, ends, axes);
}
TEST_CASE(slice_var_inputs_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"axes", {1, 2}}}),
input,
starts,
ends);
}
TEST_CASE(slice_var_inputs_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 4}, {0, 4}}},
migraphx::make_op("slice"),
input,
starts,
ends,
axes);
}
TEST_CASE(slice_dyn_shape0) TEST_CASE(slice_dyn_shape0)
{ {
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}}; migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
...@@ -2870,7 +2948,7 @@ TEST_CASE(slice_dyn_shape2) ...@@ -2870,7 +2948,7 @@ TEST_CASE(slice_dyn_shape2)
TEST_CASE(slice_dyn_shape3) TEST_CASE(slice_dyn_shape3)
{ {
// TODO: When variable dimension slicing is allowed, Slice to a size smaller than min. // TODO: When non-fixed dimension slicing is allowed, Slice to a size smaller than min.
// Until then, this action is an error. // Until then, this action is an error.
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 8}, {2, 3}}}; migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 8}, {2, 3}}};
throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
...@@ -2901,8 +2979,6 @@ TEST_CASE(slice_dyn_shape5) ...@@ -2901,8 +2979,6 @@ TEST_CASE(slice_dyn_shape5)
input); input);
} }
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(softmax_dyn0) TEST_CASE(softmax_dyn0)
{ {
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}}; migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}};
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include "test.hpp" #include "test.hpp"
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
......
...@@ -8153,6 +8153,115 @@ TEST_CASE(slice_test) ...@@ -8153,6 +8153,115 @@ TEST_CASE(slice_test)
} }
} }
TEST_CASE(slice_var_inputs_static0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {1};
std::vector<int32_t> end_data = {3};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {-2};
std::vector<int32_t> end_data = {2831};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::float_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int64_type, {3}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice"), l0, starts, ends, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int64_t> start_data = {0, 0, 0};
std::vector<int64_t> end_data = {2, 2, 2};
std::vector<int64_t> axes_data = {0, 1, 2};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<float> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), input, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> start_data = {1};
std::vector<int> end_data = {3};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_dyn_test0) TEST_CASE(slice_dyn_test0)
{ {
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is // Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1) ...@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", outer);
...@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2) ...@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
auto create_program = [&] { auto create_program = [&] {
migraphx::module m; migraphx::module m;
auto x = m.add_parameter("x", outer); auto x = m.add_parameter("x", outer);
...@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add) ...@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
TEST_CASE(simplify_inner_broadcast1) TEST_CASE(simplify_inner_broadcast1)
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
...@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1) ...@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
TEST_CASE(simplify_inner_broadcast2) TEST_CASE(simplify_inner_broadcast2)
{ {
auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
...@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2) ...@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
TEST_CASE(simplify_inner_broadcast_scalar) TEST_CASE(simplify_inner_broadcast_scalar)
{ {
auto b = migraphx::op::multibroadcast{{32, 384}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {32, 384}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
...@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y); auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum); auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb); m2.add_instruction(pass_op{}, sumb);
...@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
TEST_CASE(simplify_inner_broadcast_different_dims) TEST_CASE(simplify_inner_broadcast_different_dims)
{ {
auto b = migraphx::op::multibroadcast{{2, 384, 768}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
...@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) ...@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y); auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum); auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb); m2.add_instruction(pass_op{}, sumb);
...@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) ...@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
TEST_CASE(simplify_inner_broadcast_different_broadcasts) TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{ {
auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 24, 112, 112}}});
auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}}; auto mb = migraphx::make_op("multibroadcast", {{"out_lens", {1, 24, 112, 112}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}});
...@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast) ...@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast) ...@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) ...@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) ...@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) ...@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) ...@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu) ...@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu) ...@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape) ...@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto r = migraphx::op::reshape{{3, 4}}; auto r = migraphx::make_op("reshape", {{"dims", {3, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape) ...@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis) ...@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::module m1; migraphx::module m1;
{ {
auto r = migraphx::op::reshape{{3, 2, 4}}; auto r = migraphx::make_op("reshape", {{"dims", {3, 2, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice) ...@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
...@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice) ...@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
...@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice) ...@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis) ...@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis) ...@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes) ...@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 3}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}), migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
...@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1) ...@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1) ...@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2) ...@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2) ...@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1) ...@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_multibroadcasts2) TEST_CASE(concat_multibroadcasts2)
...@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2) ...@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 0);
} }
TEST_CASE(concat_multibroadcasts3) TEST_CASE(concat_multibroadcasts3)
...@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3) ...@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
} }
TEST_CASE(concat_multibroadcasts4) TEST_CASE(concat_multibroadcasts4)
...@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1) ...@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 3);
} }
TEST_CASE(concat_transpose2) TEST_CASE(concat_transpose2)
...@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2) ...@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_transpose3) TEST_CASE(concat_transpose3)
...@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3) ...@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_transpose4) TEST_CASE(concat_transpose4)
......
...@@ -37,7 +37,6 @@ ...@@ -37,7 +37,6 @@
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp> #include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
...@@ -840,12 +839,8 @@ TEST_CASE(slice_test) ...@@ -840,12 +839,8 @@ TEST_CASE(slice_test)
mm->add_literal(migraphx::literal{s0, {1, 0}}); mm->add_literal(migraphx::literal{s0, {1, 0}});
mm->add_literal(migraphx::literal{s0, {2, -1}}); mm->add_literal(migraphx::literal{s0, {2, -1}});
migraphx::op::slice op; mm->add_instruction(
op.starts = {1, 0}; migraphx::make_op("slice", {{"starts", {1, 0}}, {"ends", {3, 10}}, {"axes", {0, 1}}}), l0);
op.ends = {3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
mm->add_instruction(op, l0);
auto prog = optimize_tf("slice_test.pb", false); auto prog = optimize_tf("slice_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test) ...@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
std::size_t num_axes = 4; auto l2 = mm->add_instruction(
migraphx::op::slice op; migraphx::make_op(
op.starts = {0, 0, 0, 0}; "slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}),
op.ends = {1, 1, 1, 5}; l1);
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
auto l2 = mm->add_instruction(op, l1);
auto shrink_axis = 1; auto shrink_axis = 1;
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);
auto prog = optimize_tf("stridedslice_test.pb", true); auto prog = optimize_tf("stridedslice_test.pb", true);
...@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test) ...@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 1, 1, 0};
op.ends = {1, 3, 3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format) // add literals for starts, ends, and strides in tf (NHWC format)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{0, 1, 1, 0}); std::vector<int>{0, 1, 1, 0});
...@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test) ...@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test)
auto l1 = auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(
migraphx::make_op(
"slice", {{"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}, {"axes", {0, 1, 2, 3}}}),
l1);
auto l3 = auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
mm->add_return({l3}); mm->add_return({l3});
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_literal : verify_program<gemm_literal> struct gemm_literal : verify_program<gemm_literal>
{ {
...@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal> ...@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto a = mm->add_literal(migraphx::generate_literal(a_shape)); auto a = mm->add_literal(migraphx::generate_literal(a_shape));
auto b = mm->add_parameter("b", b_shape); auto b = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::op::dot{}, a, b); mm->add_instruction(migraphx::make_op("dot"), a, b);
return p; return p;
} }
......
...@@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt ...@@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt
# Add newer cmake to the path # Add newer cmake to the path
export PATH="/opt/cmake/bin:$PATH" export PATH="/opt/cmake/bin:$PATH"
export CXXFLAGS="-D__HIP_PLATFORM_AMD__=1 -w" export CXXFLAGS="-D__HIP_PLATFORM_AMD__=1 -w"
./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root ./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --build_wheel --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root
cd build/Linux/Release cd build/Linux/Release
#Add test launcher for onnxrt tests #Add test launcher for onnxrt tests
......
FROM registry.suse.com/suse/sle15:15.4
RUN sh -c 'echo -e "\
[rocm]\n\
name=rocm\n\
baseurl=https://repo.radeon.com/rocm/zyp/5.5/main\n\
enabled=1\n\
gpgcheck=1\n\
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\
" > /etc/zypp/repos.d/rocm.repo'
RUN cat /etc/zypp/repos.d/rocm.repo
RUN zypper -n --gpg-auto-import-keys refresh
RUN zypper install -y -t pattern devel_basis enhanced_base
RUN zypper --gpg-auto-import-keys install -y \
doxygen \
gcc-c++ \
gdb \
git \
python3-pip
# Workaround broken rocm packages
RUN ln -s /opt/rocm-* /opt/rocm
RUN echo "/opt/rocm/lib" > /etc/ld.so.conf.d/rocm.conf
RUN echo "/opt/rocm/llvm/lib" > /etc/ld.so.conf.d/rocm-llvm.conf
RUN ldconfig
ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8
# Install yapf
RUN pip3 install yapf==0.28.0
# Install doc requirements
# ADD docs/.sphinx/requirements.txt /doc-requirements.txt
# RUN pip3 install -r /doc-requirements.txt
# Install dependencies
ADD dev-requirements.txt /dev-requirements.txt
ADD requirements.txt /requirements.txt
ADD rbuild.ini /rbuild.ini
COPY ./tools/install_prereqs.sh /
RUN /install_prereqs.sh /usr/local / && rm /install_prereqs.sh
...@@ -31,9 +31,30 @@ set -e ...@@ -31,9 +31,30 @@ set -e
export LC_ALL=C.UTF-8 export LC_ALL=C.UTF-8
export LANG=C.UTF-8 export LANG=C.UTF-8
source /etc/os-release
if [[ ("${ID}" == "sles") ]]; then
zypper -n --gpg-auto-import-keys install -y \
cmake \
miopen-hip-devel \
openmp-extras-devel \
python3-devel \
python3-pip \
rocblas-devel \
rocm-cmake
else
# Need pip3 and Python headers to build dependencies
apt update && apt install -y \
cmake \
libnuma-dev \
miopen-hip-dev \
openmp-extras \
python3-dev \
python3-pip \
rocblas-dev \
rocm-cmake
fi
# Need pip3 and Python headers to build dependencies
apt update && apt install -y python3-pip python3-dev cmake rocm-cmake rocblas miopen-hip openmp-extras
# Needed for cmake to build various pip packages # Needed for cmake to build various pip packages
pip3 install setuptools wheel pip3 install setuptools wheel
...@@ -56,9 +77,11 @@ echo "Dependencies are installed at $PREFIX" ...@@ -56,9 +77,11 @@ echo "Dependencies are installed at $PREFIX"
# Install deps with rbuild # Install deps with rbuild
rbuild prepare -d $PREFIX -s develop rbuild prepare -d $PREFIX -s develop
if [[ ("${ID}" != "sles") ]]; then
export CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=ON" export CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
pip3 install onnx==1.10.2 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==23.0 pip3 install onnx==1.10.2 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==23.0
# pin version of protobuf in Python for onnx runtime unit tests between dist versions # pin version of protobuf in Python for onnx runtime unit tests between dist versions
pip3 install protobuf==3.20.0 pip3 install protobuf==3.20.0
fi
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment