Commit 309df0d2 authored by wsttiger's avatar wsttiger
Browse files

Merge branch 'master' into concat

parents 3de56715 76f7ae49
...@@ -17,6 +17,9 @@ using pooling_descriptor = MIGRAPH_MANAGE_PTR(miopenPoolingDescriptor_t, ...@@ -17,6 +17,9 @@ using pooling_descriptor = MIGRAPH_MANAGE_PTR(miopenPoolingDescriptor_t,
miopenDestroyPoolingDescriptor); miopenDestroyPoolingDescriptor);
using activation_descriptor = MIGRAPH_MANAGE_PTR(miopenActivationDescriptor_t, using activation_descriptor = MIGRAPH_MANAGE_PTR(miopenActivationDescriptor_t,
miopenDestroyActivationDescriptor); miopenDestroyActivationDescriptor);
using fusion_plan_descriptor = MIGRAPH_MANAGE_PTR(miopenFusionPlanDescriptor_t,
miopenDestroyFusionPlan);
using fused_operator_args = MIGRAPH_MANAGE_PTR(miopenOperatorArgs_t, miopenDestroyOperatorArgs);
template <class Result, class F, class... Ts> template <class Result, class F, class... Ts>
Result make_obj(F f, Ts... xs) Result make_obj(F f, Ts... xs)
...@@ -84,6 +87,24 @@ inline activation_descriptor make_relu() ...@@ -84,6 +87,24 @@ inline activation_descriptor make_relu()
return ad; return ad;
} }
inline fusion_plan_descriptor make_fusion_plan(const shape& input)
{
auto t = make_tensor(input);
return make_obj<fusion_plan_descriptor>(&miopenCreateFusionPlan, miopenVerticalFusion, t.get());
}
// Temporary hack to workaround memory problems in miopen
inline fusion_plan_descriptor make_fusion_plan(const tensor_descriptor& input)
{
return make_obj<fusion_plan_descriptor>(
&miopenCreateFusionPlan, miopenVerticalFusion, input.get());
}
inline fused_operator_args make_fused_args()
{
return make_obj<fused_operator_args>(&miopenCreateOperatorArgs);
}
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <migraph/auto_contiguous.hpp> #include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraph/dead_code_elimination.hpp>
#include <migraph/simplify_reshapes.hpp> #include <migraph/simplify_reshapes.hpp>
#include <migraph/simplify_algebra.hpp>
#include <migraph/constant_propagate.hpp>
#include <migraph/eliminate_contiguous.hpp> #include <migraph/eliminate_contiguous.hpp>
#include <migraph/fwd_conv_batchnorm_rewrite.hpp> #include <migraph/fwd_conv_batchnorm_rewrite.hpp>
...@@ -25,14 +27,18 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -25,14 +27,18 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
dead_code_elimination{}, dead_code_elimination{},
fwd_conv_batchnorm_rewrite{}, fwd_conv_batchnorm_rewrite{},
dead_code_elimination{}, dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
constant_propagate{},
dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{}, dead_code_elimination{},
lowering{ctx}, lowering{ctx},
fuse_ops{},
dead_code_elimination{},
eliminate_contiguous{}, eliminate_contiguous{},
dead_code_elimination{}, dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
memory_coloring{"hip::allocate"}, memory_coloring{"hip::allocate"},
eliminate_workspace{}, eliminate_workspace{},
......
#include <migraph/constant_propagate.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct const_prop_target
{
std::string name() const { return "const_prop"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::constant_propagate{}, migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
void const_add1()
{
migraph::program p1;
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraph::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
migraph::program p2;
auto total = p2.add_literal(3);
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
void const_add2()
{
migraph::program p1;
auto one = p1.add_parameter("one", {migraph::shape::int32_type, {1}});
auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraph::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
migraph::program p2;
auto total = p2.add_literal(3);
p2.add_instruction(pass_op{}, total);
EXPECT(p1 != p2);
}
void const_add3()
{
migraph::program p1;
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{});
migraph::program p2;
auto total = p2.add_literal(5);
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
int main()
{
const_add1();
const_add2();
const_add3();
}
...@@ -174,6 +174,38 @@ struct test_add ...@@ -174,6 +174,38 @@ struct test_add
} }
}; };
struct test_triadd
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", s);
auto sum = p.add_instruction(migraph::op::add{}, x, y);
p.add_instruction(migraph::op::add{}, sum, z);
return p;
}
};
struct test_triadd2
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {2, 3}};
migraph::shape b{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraph::op::broadcast{1, s}, z);
auto sum = p.add_instruction(migraph::op::add{}, x, y);
p.add_instruction(migraph::op::add{}, sum, zb);
return p;
}
};
struct test_add_broadcast struct test_add_broadcast
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -244,6 +276,22 @@ struct test_add_broadcast5 ...@@ -244,6 +276,22 @@ struct test_add_broadcast5
} }
}; };
struct test_triadd_broadcast
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {2, 2}});
auto z = p.add_parameter("z", {migraph::shape::float_type, {2, 2, 3}});
auto by = p.add_instruction(migraph::op::broadcast{0, x->get_shape()}, y);
auto sum = p.add_instruction(migraph::op::add{}, x, by);
p.add_instruction(migraph::op::add{}, sum, z);
return p;
}
};
struct test_softmax struct test_softmax
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -593,11 +641,14 @@ int main() ...@@ -593,11 +641,14 @@ int main()
verify_program<test_concat>(); verify_program<test_concat>();
verify_program<test_concat2>(); verify_program<test_concat2>();
verify_program<test_add>(); verify_program<test_add>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>(); verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>(); verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>(); verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>(); verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>(); verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_softmax>(); verify_program<test_softmax>();
verify_program<test_softmax2>(); verify_program<test_softmax2>();
verify_program<test_conv>(); verify_program<test_conv>();
......
...@@ -239,6 +239,48 @@ void match_args7() ...@@ -239,6 +239,48 @@ void match_args7()
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
void match_either_args1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("sum"), match::name("@literal")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
void match_either_args2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("@literal"), match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
void match_either_args3()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_all_of1() void match_all_of1()
{ {
migraph::program p; migraph::program p;
...@@ -391,6 +433,10 @@ int main() ...@@ -391,6 +433,10 @@ int main()
match_args6(); match_args6();
match_args7(); match_args7();
match_either_args1();
match_either_args2();
match_either_args3();
match_all_of1(); match_all_of1();
match_all_of2(); match_all_of2();
......
#include <migraph/simplify_algebra.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct simplify_algebra_target
{
std::string name() const { return "simplify_algebra"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::simplify_algebra{}, migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
void simplify_add1()
{
migraph::program p1;
{
auto x = p1.add_parameter("x", {migraph::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraph::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, x, one);
auto sum2 = p1.add_instruction(migraph::op::add{}, y, two);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
migraph::program p2;
{
auto x = p2.add_parameter("x", {migraph::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraph::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraph::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
void simplify_add2()
{
migraph::program p1;
{
auto x = p1.add_parameter("x", {migraph::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraph::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraph::op::add{}, two, y);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
migraph::program p2;
{
auto x = p2.add_parameter("x", {migraph::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraph::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraph::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
void simplify_add3()
{
migraph::program p1;
{
auto x = p1.add_parameter("x", {migraph::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraph::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
migraph::program p2;
{
auto x = p2.add_parameter("x", {migraph::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, x);
auto sum2 = p2.add_instruction(migraph::op::add{}, one, two);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, sum2);
p2.add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
void simplify_add4()
{
migraph::program p1;
{
auto x = p1.add_parameter("x", {migraph::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraph::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraph::op::add{}, sum1, y);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum2, two);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
migraph::program p2;
{
auto x = p2.add_parameter("x", {migraph::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraph::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraph::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
int main()
{
simplify_add1();
simplify_add2();
simplify_add3();
// simplify_add4();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment