Unverified Commit e6290061 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Remove operators.hpp includes (#2086)

parent e4ef64f4
......@@ -23,9 +23,9 @@
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
......@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
auto s = ins->get_shape();
std::size_t offset = seg.first * alignment;
assert(offset < n);
m.replace_instruction(ins, op::load{s, offset}, mem);
m.replace_instruction(
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
}
// Replace zero allocation
......@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
if(ins->name() != allocation_op)
continue;
assert(ins->get_shape().bytes() == 0);
m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem);
m.replace_instruction(
ins, make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", 0}}), mem);
}
// Remove scratch parameter if its not used
......
......@@ -27,7 +27,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
......@@ -58,9 +58,8 @@ create_conv(migraphx::instruction_ref& l_img,
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op;
op.padding_mode = padding_mode;
return m.add_instruction(op, l_img, l_weights);
return m.add_instruction(
migraphx::make_op("convolution", {{"padding_mode", padding_mode}}), l_img, l_weights);
}
TEST_CASE(rewrite_pad)
......
......@@ -24,7 +24,7 @@
#include <iostream>
#include <vector>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
......@@ -90,7 +90,7 @@ TEST_CASE(int8_quantization)
migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
mm->add_instruction(migraphx::op::dot{}, pa, pb);
mm->add_instruction(migraphx::make_op("dot"), pa, pb);
return p;
};
......
......@@ -26,7 +26,6 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
......
......@@ -26,8 +26,8 @@
#include <migraphx/insert_pad.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/common.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
......@@ -58,10 +58,11 @@ create_conv(migraphx::instruction_ref& l_img,
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op;
op.padding_mode = padding_mode;
op.padding = {0, 0, 1, 1};
return m.add_instruction(op, l_img, l_weights);
return m.add_instruction(
migraphx::make_op("convolution",
{{"padding_mode", padding_mode}, {"padding", {0, 0, 1, 1}}}),
l_img,
l_weights);
}
TEST_CASE(rewrite_pad)
......
......@@ -24,7 +24,6 @@
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
......
......@@ -24,7 +24,7 @@
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
......
......@@ -24,7 +24,6 @@
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/pass_manager.hpp>
......
......@@ -24,7 +24,8 @@
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/common.hpp>
#include <sstream>
#include <migraphx/make_op.hpp>
......@@ -156,13 +157,13 @@ TEST_CASE(broadcast)
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input);
}
{
......@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape)
input);
}
template <class T>
void test_softmax_variations()
void test_softmax_variations(const std::string& name)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 0}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 1}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 3}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
throws_shape(T{axis}, input);
throws_shape(migraphx::make_op(name, {{"axis", axis}}), input);
}
}
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations("logsoftmax"); }
TEST_CASE(softmax) { test_softmax_variations("softmax"); }
TEST_CASE(lstm)
{
......@@ -2328,47 +2338,54 @@ TEST_CASE(dqlinear_mismatch_type)
throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros);
}
template <class T>
void test_reduce_ops()
void test_reduce_ops(const std::string& name)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}},
migraphx::make_op(name),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}},
migraphx::make_op(name, {{"axes", {0, 1, 2, 3}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}},
migraphx::make_op(name, {{"axes", {2, 3}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::make_op(name, {{"axes", {0}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}},
migraphx::make_op(name, {{"axes", {-1}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input);
throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input);
}
}
// dynamic shape
template <class T>
void test_dyn_reduce_ops()
void test_dyn_reduce_ops(const std::string& name)
{
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})},
T{{-1}},
migraphx::make_op(name, {{"axes", {-1}}}),
input);
}
{
......@@ -2376,7 +2393,7 @@ void test_dyn_reduce_ops()
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})},
T{{0}},
migraphx::make_op(name, {{"axes", {0}}}),
input);
}
{
......@@ -2385,24 +2402,24 @@ void test_dyn_reduce_ops()
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})},
T{{}},
migraphx::make_op(name),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
throws_shape(T{{4}}, input);
throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input);
}
}
TEST_CASE(reduce_max) { test_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod) { test_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_max) { test_reduce_ops("reduce_max"); }
TEST_CASE(reduce_mean) { test_reduce_ops("reduce_mean"); }
TEST_CASE(reduce_prod) { test_reduce_ops("reduce_prod"); }
TEST_CASE(reduce_sum) { test_reduce_ops("reduce_sum"); }
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops("reduce_max"); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops("reduce_mean"); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops("reduce_prod"); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops("reduce_sum"); }
TEST_CASE(reshape_shape)
{
......@@ -2962,8 +2979,6 @@ TEST_CASE(slice_dyn_shape5)
input);
}
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(softmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}};
......
......@@ -22,7 +22,6 @@
* THE SOFTWARE.
*/
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/pad_calc.hpp>
#include "test.hpp"
......
......@@ -24,7 +24,6 @@
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp>
......
......@@ -24,7 +24,7 @@
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
......@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", outer);
......@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
auto create_program = [&] {
migraphx::module m;
auto x = m.add_parameter("x", outer);
......@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
TEST_CASE(simplify_inner_broadcast1)
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
......@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
TEST_CASE(simplify_inner_broadcast2)
{
auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
......@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
TEST_CASE(simplify_inner_broadcast_scalar)
{
auto b = migraphx::op::multibroadcast{{32, 384}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {32, 384}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
......@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y);
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
......@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
TEST_CASE(simplify_inner_broadcast_different_dims)
{
auto b = migraphx::op::multibroadcast{{2, 384, 768}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
......@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y);
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
......@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{
auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}};
auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 24, 112, 112}}});
auto mb = migraphx::make_op("multibroadcast", {{"out_lens", {1, 24, 112, 112}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}});
......@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto r = migraphx::op::reshape{{3, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto r = migraphx::make_op("reshape", {{"dims", {3, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::module m1;
{
auto r = migraphx::op::reshape{{3, 2, 4}};
auto r = migraphx::make_op("reshape", {{"dims", {3, 2, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
......@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
......@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 3}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
......@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......
......@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
......@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_multibroadcasts2)
......@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 0);
}
TEST_CASE(concat_multibroadcasts3)
......@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
}
TEST_CASE(concat_multibroadcasts4)
......@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 3);
}
TEST_CASE(concat_transpose2)
......@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_transpose3)
......@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_transpose4)
......
......@@ -25,7 +25,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_literal : verify_program<gemm_literal>
{
......@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto a = mm->add_literal(migraphx::generate_literal(a_shape));
auto b = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::op::dot{}, a, b);
mm->add_instruction(migraphx::make_op("dot"), a, b);
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment