Commit 6d56671b authored by Paul's avatar Paul
Browse files

Formatting

parent 3b64f602
...@@ -413,18 +413,18 @@ inline auto either_arg(std::size_t i, std::size_t j) ...@@ -413,18 +413,18 @@ inline auto either_arg(std::size_t i, std::size_t j)
}; };
} }
template<class M> template <class M>
auto same_shape(M m) auto same_shape(M m)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto i = m.match(ctx, ins); auto i = m.match(ctx, ins);
if (i != ctx.not_found() and i->get_shape() == ins->get_shape()) if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
return ins; return ins;
return ctx.not_found(); return ctx.not_found();
}); });
} }
template<class... Ms> template <class... Ms>
auto same_shape(Ms... ms) auto same_shape(Ms... ms)
{ {
return all_of(same_shape(ms)...); return all_of(same_shape(ms)...);
......
...@@ -129,9 +129,7 @@ struct find_nop_reshapes ...@@ -129,9 +129,7 @@ struct find_nop_reshapes
auto reshapes = reshaper_names(); auto reshapes = reshaper_names();
reshapes.insert("transpose"); reshapes.insert("transpose");
reshapes.insert("slice"); reshapes.insert("slice");
return match::name(reshapes)( return match::name(reshapes)(match::same_shape(match::arg(0)));
match::same_shape(match::arg(0))
);
} }
void apply(program& p, match::matcher_result mr) const void apply(program& p, match::matcher_result mr) const
...@@ -189,15 +187,15 @@ struct find_concat_transpose ...@@ -189,15 +187,15 @@ struct find_concat_transpose
auto ins = mr.result; auto ins = mr.result;
auto s = ins->inputs().front()->get_shape(); auto s = ins->inputs().front()->get_shape();
assert(s.transposed()); assert(s.transposed());
auto op = any_cast<op::concat>(ins->get_operator()); auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s); auto permutation = find_permutation(s);
auto ipermutation = invert_permutation(permutation); auto ipermutation = invert_permutation(permutation);
op.axis = permutation[op.axis]; op.axis = permutation[op.axis];
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
if (i->name() == "transpose" and i->inputs().front()->get_shape().standard()) if(i->name() == "transpose" and i->inputs().front()->get_shape().standard())
return i->inputs().front(); return i->inputs().front();
return p.insert_instruction(ins, op::transpose{permutation}, i); return p.insert_instruction(ins, op::transpose{permutation}, i);
}); });
...@@ -217,7 +215,12 @@ void simplify_reshapes::apply(program& p) const ...@@ -217,7 +215,12 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions // Skip possible dead instructions
if(ins->outputs().empty() and ins != end) if(ins->outputs().empty() and ins != end)
continue; continue;
match::find_matches(p, ins, find_nop_reshapes{}, find_reshaper{}, find_transpose{}, find_concat_transpose{}); match::find_matches(p,
ins,
find_nop_reshapes{},
find_reshaper{},
find_transpose{},
find_concat_transpose{});
} }
} }
......
...@@ -217,8 +217,8 @@ TEST_CASE(transpose_partial3) ...@@ -217,8 +217,8 @@ TEST_CASE(transpose_partial3)
TEST_CASE(nop_transpose1) TEST_CASE(nop_transpose1)
{ {
migraphx::program p; migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x); auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
...@@ -248,12 +248,12 @@ TEST_CASE(nop_transpose2) ...@@ -248,12 +248,12 @@ TEST_CASE(nop_transpose2)
TEST_CASE(nop_transpose3) TEST_CASE(nop_transpose3)
{ {
migraphx::program p; migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y); auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat); auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
...@@ -265,27 +265,23 @@ TEST_CASE(nop_transpose3) ...@@ -265,27 +265,23 @@ TEST_CASE(nop_transpose3)
TEST_CASE(concat_transpose1) TEST_CASE(concat_transpose1)
{ {
migraphx::program p; migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x); auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y); auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt); auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat); auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3); EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat = std::find_if(p.begin(), p.end(), [](auto ins) { auto new_concat =
return ins.name() == "concat"; std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
});
EXPECT(bool{new_concat != p.end()}); EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3); EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment