Unverified Commit e6290061 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Remove operators.hpp includes (#2086)

parent e4ef64f4
...@@ -23,9 +23,9 @@ ...@@ -23,9 +23,9 @@
*/ */
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const ...@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
auto s = ins->get_shape(); auto s = ins->get_shape();
std::size_t offset = seg.first * alignment; std::size_t offset = seg.first * alignment;
assert(offset < n); assert(offset < n);
m.replace_instruction(ins, op::load{s, offset}, mem); m.replace_instruction(
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
} }
// Replace zero allocation // Replace zero allocation
...@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const ...@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
if(ins->name() != allocation_op) if(ins->name() != allocation_op)
continue; continue;
assert(ins->get_shape().bytes() == 0); assert(ins->get_shape().bytes() == 0);
m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem); m.replace_instruction(
ins, make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", 0}}), mem);
} }
// Remove scratch parameter if its not used // Remove scratch parameter if its not used
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -58,9 +58,8 @@ create_conv(migraphx::instruction_ref& l_img, ...@@ -58,9 +58,8 @@ create_conv(migraphx::instruction_ref& l_img,
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}}; migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3); std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = m.add_literal(migraphx::literal{s_weights, weights}); auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op; return m.add_instruction(
op.padding_mode = padding_mode; migraphx::make_op("convolution", {{"padding_mode", padding_mode}}), l_img, l_weights);
return m.add_instruction(op, l_img, l_weights);
} }
TEST_CASE(rewrite_pad) TEST_CASE(rewrite_pad)
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/gpu/fuse_mlir.hpp> #include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -90,7 +90,7 @@ TEST_CASE(int8_quantization) ...@@ -90,7 +90,7 @@ TEST_CASE(int8_quantization)
migraphx::shape sc{migraphx::shape::float_type, {5, 8}}; migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
mm->add_instruction(migraphx::op::dot{}, pa, pb); mm->add_instruction(migraphx::make_op("dot"), pa, pb);
return p; return p;
}; };
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
#include <migraphx/insert_pad.hpp> #include <migraphx/insert_pad.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/common.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -58,10 +58,11 @@ create_conv(migraphx::instruction_ref& l_img, ...@@ -58,10 +58,11 @@ create_conv(migraphx::instruction_ref& l_img,
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}}; migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3); std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = m.add_literal(migraphx::literal{s_weights, weights}); auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op; return m.add_instruction(
op.padding_mode = padding_mode; migraphx::make_op("convolution",
op.padding = {0, 0, 1, 1}; {{"padding_mode", padding_mode}, {"padding", {0, 0, 1, 1}}}),
return m.add_instruction(op, l_img, l_weights); l_img,
l_weights);
} }
TEST_CASE(rewrite_pad) TEST_CASE(rewrite_pad)
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <migraphx/layout_nhwc.hpp> #include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/op/common.hpp>
#include <sstream> #include <sstream>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -156,13 +157,13 @@ TEST_CASE(broadcast) ...@@ -156,13 +157,13 @@ TEST_CASE(broadcast)
{ {
std::vector<std::size_t> lens{1, 1}; std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}}; migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input); throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input);
} }
{ {
std::vector<std::size_t> lens{2, 2}; std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}}; migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input); throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input);
} }
{ {
...@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape) ...@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape)
input); input);
} }
template <class T> void test_softmax_variations(const std::string& name)
void test_softmax_variations()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 0}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 1}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 2}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 3}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4; int axis = 4;
throws_shape(T{axis}, input); throws_shape(migraphx::make_op(name, {{"axis", axis}}), input);
} }
} }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); } TEST_CASE(logsoftmax) { test_softmax_variations("logsoftmax"); }
TEST_CASE(softmax) { test_softmax_variations("softmax"); }
TEST_CASE(lstm) TEST_CASE(lstm)
{ {
...@@ -2328,47 +2338,54 @@ TEST_CASE(dqlinear_mismatch_type) ...@@ -2328,47 +2338,54 @@ TEST_CASE(dqlinear_mismatch_type)
throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros); throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros);
} }
template <class T> void test_reduce_ops(const std::string& name)
void test_reduce_ops()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}},
migraphx::make_op(name),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}},
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input); migraphx::make_op(name, {{"axes", {0, 1, 2, 3}}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}},
migraphx::make_op(name, {{"axes", {2, 3}}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::make_op(name, {{"axes", {0}}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}},
migraphx::make_op(name, {{"axes", {-1}}}),
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input); throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input);
} }
} }
// dynamic shape // dynamic shape
template <class T> void test_dyn_reduce_ops(const std::string& name)
void test_dyn_reduce_ops()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}}; migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})}, std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})},
T{{-1}}, migraphx::make_op(name, {{"axes", {-1}}}),
input); input);
} }
{ {
...@@ -2376,7 +2393,7 @@ void test_dyn_reduce_ops() ...@@ -2376,7 +2393,7 @@ void test_dyn_reduce_ops()
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})}, std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})},
T{{0}}, migraphx::make_op(name, {{"axes", {0}}}),
input); input);
} }
{ {
...@@ -2385,24 +2402,24 @@ void test_dyn_reduce_ops() ...@@ -2385,24 +2402,24 @@ void test_dyn_reduce_ops()
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})}, std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})},
T{{}}, migraphx::make_op(name),
input); input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}}; migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
throws_shape(T{{4}}, input); throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input);
} }
} }
TEST_CASE(reduce_max) { test_reduce_ops<migraphx::op::reduce_max>(); } TEST_CASE(reduce_max) { test_reduce_ops("reduce_max"); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); } TEST_CASE(reduce_mean) { test_reduce_ops("reduce_mean"); }
TEST_CASE(reduce_prod) { test_reduce_ops<migraphx::op::reduce_prod>(); } TEST_CASE(reduce_prod) { test_reduce_ops("reduce_prod"); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); } TEST_CASE(reduce_sum) { test_reduce_ops("reduce_sum"); }
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_max>(); } TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops("reduce_max"); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_mean>(); } TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops("reduce_mean"); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_prod>(); } TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops("reduce_prod"); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_sum>(); } TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops("reduce_sum"); }
TEST_CASE(reshape_shape) TEST_CASE(reshape_shape)
{ {
...@@ -2962,8 +2979,6 @@ TEST_CASE(slice_dyn_shape5) ...@@ -2962,8 +2979,6 @@ TEST_CASE(slice_dyn_shape5)
input); input);
} }
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(softmax_dyn0) TEST_CASE(softmax_dyn0)
{ {
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}}; migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}};
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include "test.hpp" #include "test.hpp"
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1) ...@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", outer);
...@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2) ...@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
auto create_program = [&] { auto create_program = [&] {
migraphx::module m; migraphx::module m;
auto x = m.add_parameter("x", outer); auto x = m.add_parameter("x", outer);
...@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add) ...@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
TEST_CASE(simplify_inner_broadcast1) TEST_CASE(simplify_inner_broadcast1)
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
...@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1) ...@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
TEST_CASE(simplify_inner_broadcast2) TEST_CASE(simplify_inner_broadcast2)
{ {
auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
...@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2) ...@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
TEST_CASE(simplify_inner_broadcast_scalar) TEST_CASE(simplify_inner_broadcast_scalar)
{ {
auto b = migraphx::op::multibroadcast{{32, 384}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {32, 384}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
...@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y); auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum); auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb); m2.add_instruction(pass_op{}, sumb);
...@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
TEST_CASE(simplify_inner_broadcast_different_dims) TEST_CASE(simplify_inner_broadcast_different_dims)
{ {
auto b = migraphx::op::multibroadcast{{2, 384, 768}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
...@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) ...@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y); auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum); auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb); m2.add_instruction(pass_op{}, sumb);
...@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) ...@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
TEST_CASE(simplify_inner_broadcast_different_broadcasts) TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{ {
auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 24, 112, 112}}});
auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}}; auto mb = migraphx::make_op("multibroadcast", {{"out_lens", {1, 24, 112, 112}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}});
...@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast) ...@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast) ...@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) ...@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) ...@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) ...@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) ...@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu) ...@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu) ...@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape) ...@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto r = migraphx::op::reshape{{3, 4}}; auto r = migraphx::make_op("reshape", {{"dims", {3, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape) ...@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis) ...@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::module m1; migraphx::module m1;
{ {
auto r = migraphx::op::reshape{{3, 2, 4}}; auto r = migraphx::make_op("reshape", {{"dims", {3, 2, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice) ...@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
...@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice) ...@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
...@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice) ...@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis) ...@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis) ...@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes) ...@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 3}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}), migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
...@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1) ...@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1) ...@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2) ...@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2) ...@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1) ...@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_multibroadcasts2) TEST_CASE(concat_multibroadcasts2)
...@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2) ...@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 0);
} }
TEST_CASE(concat_multibroadcasts3) TEST_CASE(concat_multibroadcasts3)
...@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3) ...@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
} }
TEST_CASE(concat_multibroadcasts4) TEST_CASE(concat_multibroadcasts4)
...@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1) ...@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 3);
} }
TEST_CASE(concat_transpose2) TEST_CASE(concat_transpose2)
...@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2) ...@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_transpose3) TEST_CASE(concat_transpose3)
...@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3) ...@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_transpose4) TEST_CASE(concat_transpose4)
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_literal : verify_program<gemm_literal> struct gemm_literal : verify_program<gemm_literal>
{ {
...@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal> ...@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto a = mm->add_literal(migraphx::generate_literal(a_shape)); auto a = mm->add_literal(migraphx::generate_literal(a_shape));
auto b = mm->add_parameter("b", b_shape); auto b = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::op::dot{}, a, b); mm->add_instruction(migraphx::make_op("dot"), a, b);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment