Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_3args : verify_program<gemm_multi_3args> struct gemm_multi_3args : verify_program<gemm_multi_3args>
{ {
...@@ -19,7 +19,8 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args> ...@@ -19,7 +19,8 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{ {
...@@ -19,7 +19,8 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> ...@@ -19,7 +19,8 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
float alpha = 0.0f; float alpha = 0.0f;
float beta = 1.0f; float beta = 1.0f;
mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0> struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
{ {
...@@ -19,7 +19,8 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0> ...@@ -19,7 +19,8 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25> struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
{ {
...@@ -19,7 +19,8 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25> ...@@ -19,7 +19,8 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2> struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
{ {
...@@ -15,7 +15,7 @@ struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2> ...@@ -15,7 +15,7 @@ struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::op::dot{}, l1, l2); mm->add_instruction(migraphx::make_op("dot"), l1, l2);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3> struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
{ {
...@@ -15,7 +15,7 @@ struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3> ...@@ -15,7 +15,7 @@ struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::op::dot{}, l1, l2); mm->add_instruction(migraphx::make_op("dot"), l1, l2);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_transpose : verify_program<gemm_multi_transpose> struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
{ {
...@@ -14,11 +14,11 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose> ...@@ -14,11 +14,11 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto tl2 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, l2); auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), l2);
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, tl2); mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, tl2);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct quant_conv : verify_program<quant_conv> struct quant_conv : verify_program<quant_conv>
{ {
...@@ -14,7 +14,7 @@ struct quant_conv : verify_program<quant_conv> ...@@ -14,7 +14,7 @@ struct quant_conv : verify_program<quant_conv>
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{}, pa, pc); mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct quant_conv_padding : verify_program<quant_conv_padding> struct quant_conv_padding : verify_program<quant_conv_padding>
{ {
...@@ -14,7 +14,10 @@ struct quant_conv_padding : verify_program<quant_conv_padding> ...@@ -14,7 +14,10 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc); mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}),
pa,
pc);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride> struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
{ {
...@@ -14,7 +14,10 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride> ...@@ -14,7 +14,10 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc); mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
pa,
pc);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1> struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
{ {
...@@ -17,7 +17,7 @@ struct quant_dot_3args_1 : verify_program<quant_dot_3args_1> ...@@ -17,7 +17,7 @@ struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::op::quant_dot{}, l1, l2, l3); mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2> struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
{ {
...@@ -15,10 +15,11 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2> ...@@ -15,10 +15,11 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3); mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3> struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
{ {
...@@ -16,9 +16,10 @@ struct quant_dot_3args_3 : verify_program<quant_dot_3args_3> ...@@ -16,9 +16,10 @@ struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l2); auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3); mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4> struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
{ {
...@@ -15,11 +15,12 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4> ...@@ -15,11 +15,12 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l2); auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3); mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_abs : verify_program<test_abs> struct test_abs : verify_program<test_abs>
{ {
...@@ -11,7 +11,7 @@ struct test_abs : verify_program<test_abs> ...@@ -11,7 +11,7 @@ struct test_abs : verify_program<test_abs>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::op::abs{}, x); mm->add_instruction(migraphx::make_op("abs"), x);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_acos : verify_program<test_acos> struct test_acos : verify_program<test_acos>
{ {
...@@ -12,7 +12,7 @@ struct test_acos : verify_program<test_acos> ...@@ -12,7 +12,7 @@ struct test_acos : verify_program<test_acos>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {16}}; migraphx::shape s{migraphx::shape::double_type, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::acos{}, x); mm->add_instruction(migraphx::make_op("acos"), x);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_acosh : verify_program<test_acosh> struct test_acosh : verify_program<test_acosh>
{ {
...@@ -14,10 +14,12 @@ struct test_acosh : verify_program<test_acosh> ...@@ -14,10 +14,12 @@ struct test_acosh : verify_program<test_acosh>
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(1.1f); auto min_val = mm->add_literal(1.1f);
auto max_val = mm->add_literal(100.0f); auto max_val = mm->add_literal(100.0f);
min_val = mm->add_instruction(migraphx::op::multibroadcast{{16}}, min_val); min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {16}}}),
max_val = mm->add_instruction(migraphx::op::multibroadcast{{16}}, max_val); min_val);
auto cx = mm->add_instruction(migraphx::op::clip{}, x, min_val, max_val); max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {16}}}),
mm->add_instruction(migraphx::op::acosh{}, cx); max_val);
auto cx = mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val);
mm->add_instruction(migraphx::make_op("acosh"), cx);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_add : verify_program<test_add> struct test_add : verify_program<test_add>
{ {
...@@ -13,7 +13,7 @@ struct test_add : verify_program<test_add> ...@@ -13,7 +13,7 @@ struct test_add : verify_program<test_add>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::op::add{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
return p; return p;
} }
}; };
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_add_broadcast : verify_program<test_add_broadcast> struct test_add_broadcast : verify_program<test_add_broadcast>
...@@ -14,8 +15,9 @@ struct test_add_broadcast : verify_program<test_add_broadcast> ...@@ -14,8 +15,9 @@ struct test_add_broadcast : verify_program<test_add_broadcast>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}}); auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto by = mm->add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y); auto by = mm->add_instruction(
mm->add_instruction(migraphx::op::add{}, x, by); migraphx::make_op("broadcast", {{"axis", 0}, {"dims", x->get_shape().lens()}}), y);
mm->add_instruction(migraphx::make_op("add"), x, by);
return p; return p;
} }
}; };
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_add_broadcast2 : verify_program<test_add_broadcast2> struct test_add_broadcast2 : verify_program<test_add_broadcast2>
...@@ -14,8 +15,9 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2> ...@@ -14,8 +15,9 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {3}}); auto y = mm->add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = mm->add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y); auto by = mm->add_instruction(
mm->add_instruction(migraphx::op::add{}, x, by); migraphx::make_op("broadcast", {{"axis", 1}, {"dims", x->get_shape().lens()}}), y);
mm->add_instruction(migraphx::make_op("add"), x, by);
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment