Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_clip : verify_program<test_clip> struct test_clip : verify_program<test_clip>
{ {
...@@ -13,9 +13,11 @@ struct test_clip : verify_program<test_clip> ...@@ -13,9 +13,11 @@ struct test_clip : verify_program<test_clip>
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, min_val); min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}),
max_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, max_val); min_val);
mm->add_instruction(migraphx::op::clip{}, x, min_val, max_val); max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}),
max_val);
mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_concat_axis_0 : verify_program<test_concat_axis_0> struct test_concat_axis_0 : verify_program<test_concat_axis_0>
{ {
...@@ -17,7 +17,7 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0> ...@@ -17,7 +17,7 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0>
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_concat_axis_1 : verify_program<test_concat_axis_1> struct test_concat_axis_1 : verify_program<test_concat_axis_1>
{ {
...@@ -17,7 +17,7 @@ struct test_concat_axis_1 : verify_program<test_concat_axis_1> ...@@ -17,7 +17,7 @@ struct test_concat_axis_1 : verify_program<test_concat_axis_1>
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_concat_axis_neg_1 : verify_program<test_concat_axis_neg_1> struct test_concat_axis_neg_1 : verify_program<test_concat_axis_neg_1>
{ {
...@@ -17,7 +17,7 @@ struct test_concat_axis_neg_1 : verify_program<test_concat_axis_neg_1> ...@@ -17,7 +17,7 @@ struct test_concat_axis_neg_1 : verify_program<test_concat_axis_neg_1>
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_concat_pooling : verify_program<test_concat_pooling> struct test_concat_pooling : verify_program<test_concat_pooling>
{ {
...@@ -12,13 +12,19 @@ struct test_concat_pooling : verify_program<test_concat_pooling> ...@@ -12,13 +12,19 @@ struct test_concat_pooling : verify_program<test_concat_pooling>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}});
auto transpose = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input); auto transpose =
auto concat = mm->add_instruction(migraphx::op::concat{3}, transpose); mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), input);
auto concat_t = mm->add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, concat); auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), transpose);
auto concat_t =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), concat);
auto pooling = auto pooling = mm->add_instruction(migraphx::make_op("pooling",
mm->add_instruction(migraphx::op::pooling{"average", {0, 0}, {1, 1}, {8, 8}}, concat_t); {{"mode", "average"},
mm->add_instruction(migraphx::op::relu{}, pooling); {"padding", {0, 0}},
{"stride", {1, 1}},
{"lengths", {8, 8}}}),
concat_t);
mm->add_instruction(migraphx::make_op("relu"), pooling);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_concat_relu : verify_program<test_concat_relu> struct test_concat_relu : verify_program<test_concat_relu>
{ {
...@@ -17,11 +17,11 @@ struct test_concat_relu : verify_program<test_concat_relu> ...@@ -17,11 +17,11 @@ struct test_concat_relu : verify_program<test_concat_relu>
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
auto r0 = mm->add_instruction(migraphx::op::relu{}, l0); auto r0 = mm->add_instruction(migraphx::make_op("relu"), l0);
auto r1 = mm->add_instruction(migraphx::op::relu{}, l1); auto r1 = mm->add_instruction(migraphx::make_op("relu"), l1);
auto r2 = mm->add_instruction(migraphx::op::relu{}, l2); auto r2 = mm->add_instruction(migraphx::make_op("relu"), l2);
auto c0 = mm->add_instruction(migraphx::op::concat{axis}, r0, r1, r2); auto c0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), r0, r1, r2);
mm->add_instruction(migraphx::op::relu{}, c0); mm->add_instruction(migraphx::make_op("relu"), c0);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_concat_transpose : verify_program<test_concat_transpose> struct test_concat_transpose : verify_program<test_concat_transpose>
{ {
...@@ -16,9 +16,9 @@ struct test_concat_transpose : verify_program<test_concat_transpose> ...@@ -16,9 +16,9 @@ struct test_concat_transpose : verify_program<test_concat_transpose>
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto lp1 = mm->add_parameter("y", s1); auto lp1 = mm->add_parameter("y", s1);
auto l1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, lp1); auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), lp1);
auto l2 = mm->add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_concat_transpose2 : verify_program<test_concat_transpose2> struct test_concat_transpose2 : verify_program<test_concat_transpose2>
{ {
...@@ -17,8 +17,8 @@ struct test_concat_transpose2 : verify_program<test_concat_transpose2> ...@@ -17,8 +17,8 @@ struct test_concat_transpose2 : verify_program<test_concat_transpose2>
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto lp2 = mm->add_parameter("z", s2); auto lp2 = mm->add_parameter("z", s2);
auto l2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, lp2); auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), lp2);
mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_concat_transpose3 : verify_program<test_concat_transpose3> struct test_concat_transpose3 : verify_program<test_concat_transpose3>
{ {
...@@ -16,10 +16,10 @@ struct test_concat_transpose3 : verify_program<test_concat_transpose3> ...@@ -16,10 +16,10 @@ struct test_concat_transpose3 : verify_program<test_concat_transpose3>
migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {5, 2}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto lp1 = mm->add_parameter("y", s1); auto lp1 = mm->add_parameter("y", s1);
auto l1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, lp1); auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), lp1);
auto lp2 = mm->add_parameter("z", s2); auto lp2 = mm->add_parameter("z", s2);
auto l2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, lp2); auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), lp2);
mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p; return p;
} }
}; };
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <cassert> #include <cassert>
struct test_contiguous : verify_program<test_contiguous> struct test_contiguous : verify_program<test_contiguous>
...@@ -13,7 +14,7 @@ struct test_contiguous : verify_program<test_contiguous> ...@@ -13,7 +14,7 @@ struct test_contiguous : verify_program<test_contiguous>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::contiguous{}, x); mm->add_instruction(migraphx::make_op("contiguous"), x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <cassert> #include <cassert>
struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast> struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast>
...@@ -13,7 +14,7 @@ struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast> ...@@ -13,7 +14,7 @@ struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}}; migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::contiguous{}, x); mm->add_instruction(migraphx::make_op("contiguous"), x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <cassert> #include <cassert>
struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broadcast_transpose> struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broadcast_transpose>
...@@ -13,7 +14,7 @@ struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broa ...@@ -13,7 +14,7 @@ struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broa
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}}; migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::contiguous{}, x); mm->add_instruction(migraphx::make_op("contiguous"), x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
...@@ -14,7 +14,7 @@ struct test_conv : verify_program<test_conv> ...@@ -14,7 +14,7 @@ struct test_conv : verify_program<test_conv>
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::op::convolution{}, input, weights); mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_conv2 : verify_program<test_conv2> struct test_conv2 : verify_program<test_conv2>
{ {
...@@ -14,7 +14,11 @@ struct test_conv2 : verify_program<test_conv2> ...@@ -14,7 +14,11 @@ struct test_conv2 : verify_program<test_conv2>
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}});
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}});
mm->add_instruction(migraphx::op::convolution{{0, 0}, {1, 1}, {1, 1}}, input, weights); mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input,
weights);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_conv3d : verify_program<test_conv3d> struct test_conv3d : verify_program<test_conv3d>
{ {
...@@ -15,7 +15,11 @@ struct test_conv3d : verify_program<test_conv3d> ...@@ -15,7 +15,11 @@ struct test_conv3d : verify_program<test_conv3d>
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3, 3}});
mm->add_instruction( mm->add_instruction(
migraphx::op::convolution{{0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, input, weights); migraphx::make_op(
"convolution",
{{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}),
input,
weights);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_conv_add : verify_program<test_conv_add> struct test_conv_add : verify_program<test_conv_add>
{ {
...@@ -16,10 +16,10 @@ struct test_conv_add : verify_program<test_conv_add> ...@@ -16,10 +16,10 @@ struct test_conv_add : verify_program<test_conv_add>
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
auto v = mm->add_literal( auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2)); migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(migraphx::op::exp{}, sum); mm->add_instruction(migraphx::make_op("exp"), sum);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides> struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides>
{ {
...@@ -16,10 +16,11 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st ...@@ -16,10 +16,11 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
auto v = mm->add_literal( auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2)); migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v); auto conv2 = mm->add_instruction(
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2); migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
mm->add_instruction(migraphx::op::exp{}, sum); auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(migraphx::make_op("exp"), sum);
return p; return p;
} }
}; };
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu> struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
...@@ -16,18 +17,21 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu> ...@@ -16,18 +17,21 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto l0 = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}}, auto l0 = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}},
{2.0f, 2.0f, 2.0f, 2.0f}}; {2.0f, 2.0f, 2.0f, 2.0f}};
auto bias = mm->add_literal(l0); auto bias = mm->add_literal(l0);
auto conv = mm->add_instruction(migraphx::op::convolution{}, input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_add = auto bcast_add = mm->add_instruction(
mm->add_instruction(migraphx::op::broadcast{1, conv->get_shape().lens()}, bias); migraphx::make_op("broadcast", {{"axis", 1}, {"dims", conv->get_shape().lens()}}),
auto bias_add = mm->add_instruction(migraphx::op::add{}, conv, bcast_add); bias);
auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_add);
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, min_val); min_val = mm->add_instruction(
max_val = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, max_val); migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
mm->add_instruction(migraphx::op::clip{}, bias_add, min_val, max_val); max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), bias_add, min_val, max_val);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_conv_bn : verify_program<test_conv_bn> struct test_conv_bn : verify_program<test_conv_bn>
{ {
...@@ -14,15 +14,19 @@ struct test_conv_bn : verify_program<test_conv_bn> ...@@ -14,15 +14,19 @@ struct test_conv_bn : verify_program<test_conv_bn>
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}};
migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}};
migraphx::shape vars{migraphx::shape::float_type, {64}}; migraphx::shape vars{migraphx::shape::float_type, {64}};
auto x = mm->add_parameter("x", xs); auto x = mm->add_parameter("x", xs);
auto w = mm->add_parameter("w", ws); auto w = mm->add_parameter("w", ws);
auto conv = mm->add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); auto conv = mm->add_instruction(
auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); migraphx::make_op("convolution",
auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); x,
w);
auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
mm->add_instruction( mm->add_instruction(
migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance); migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{ {
...@@ -14,17 +14,26 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> ...@@ -14,17 +14,26 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}};
migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}};
migraphx::shape vars{migraphx::shape::float_type, {64}}; migraphx::shape vars{migraphx::shape::float_type, {64}};
auto x = mm->add_parameter("x", xs); auto x = mm->add_parameter("x", xs);
auto w = mm->add_parameter("w", ws); auto w = mm->add_parameter("w", ws);
auto conv = mm->add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); auto conv = mm->add_instruction(
auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); migraphx::make_op("convolution",
auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); x,
w);
auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto bn = mm->add_instruction( auto bn = mm->add_instruction(
migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance); migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance);
auto relu = mm->add_instruction(migraphx::op::relu{}, bn); auto relu = mm->add_instruction(migraphx::make_op("relu"), bn);
mm->add_instruction(migraphx::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu); mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {1, 1}},
{"stride", {2, 2}},
{"lengths", {3, 3}}}),
relu);
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment