Commit 50e6d5eb authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Flatten nested concats (#391)

* Flatten nested concats

* Formatting

* Rename tests
parent 756c5908
......@@ -179,6 +179,38 @@ struct find_concat_transpose
}
};
struct find_nested_concat
{
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](match::name("concat")));
}
static std::size_t get_axis(instruction_ref ins)
{
auto op = any_cast<op::concat>(ins->get_operator());
return op.axis;
}
void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto axis = get_axis(ins);
std::vector<instruction_ref> args;
fix([&](auto self, auto&& inputs) {
for(auto&& i : inputs)
{
if(i->name() == "concat" and get_axis(i) == axis and i->outputs().size() == 1)
self(i->inputs());
else
args.push_back(i);
}
})(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args);
}
};
void simplify_reshapes::apply(program& p) const
{
for(int i = 0; i < 2; i++)
......@@ -196,7 +228,8 @@ void simplify_reshapes::apply(program& p) const
find_nop_reshapes{},
find_reshaper{},
find_transpose{},
find_concat_transpose{});
find_concat_transpose{},
find_nested_concat{});
}
}
}
......
......@@ -2,6 +2,7 @@
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -328,4 +329,42 @@ TEST_CASE(concat_transpose3)
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
TEST_CASE(nested_concat)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
}
TEST_CASE(nested_concat_partial)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto l = p.add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment