"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "8b10c601680cf2090f95e2819bee5a86a3e20045"
Commit a2092da6 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added optimization pass to remove concat operator when appropriate

parent ceaf5ee0
...@@ -6,6 +6,7 @@ add_library(migraph ...@@ -6,6 +6,7 @@ add_library(migraph
dead_code_elimination.cpp dead_code_elimination.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
eliminate_contiguous.cpp eliminate_contiguous.cpp
eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
env.cpp env.cpp
generate.cpp generate.cpp
......
#include <iterator>
#include <migraph/eliminate_concat.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace migraph {
void eliminate_concat::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// Look for the concat operator
if(ins->name() != concat_opt.name())
continue;
// If any inputs are literals then abort
if(std::any_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
return arg->name() == "@literal";
}))
continue;
// We can only do this optimization when concat axis is either the leftmost
// axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator());
if (concat_op.axis == 0 ||
std::all_of(lens.begin(), lens.begin()+concat_op.axis,
[] (auto x) {
return x == 1;
}))
{
// Last input should be an allocation
auto last = ins->inputs().back();
if (last->name() != concat_opt.allocate()) continue;
// Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations;
for (auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end()-1; ins2++)
{
auto last2 = (*ins2)->inputs().back();
if (last2->name() == concat_opt.allocate())
{
allocations.push_back(last2);
}
}
// Need to sort the allocations, so that we know where to
// insert the "super"-allocation
std::sort(allocations.begin(), allocations.end(), [&] (instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y);
});
// Move "super" allocation to the front
auto first = allocations.front();
auto super = p.move_instruction(last, first);
std::size_t offset = 0;
for (auto x : allocations)
{
migraph::op::load op{x->get_shape(), offset};
p.replace_instruction(x, op, {super});
offset += x->get_shape().elements();
}
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end()-1,
std::back_inserter(args));
p.replace_instruction(ins, migraph::op::identity{}, args);
}
}
}
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/concat_opt.hpp>
namespace migraph {
struct program;
struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
...@@ -617,9 +617,17 @@ struct unary ...@@ -617,9 +617,17 @@ struct unary
} }
}; };
struct identity : unary struct identity
{ {
std::string name() const { return "identity"; } std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.at(0);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
}; };
struct abs : unary struct abs : unary
......
#include <migraph/eliminate_concat.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct concat
{
concat(std::size_t axis)
{
op.axis = axis;
}
migraph::op::concat op;
std::string name() const { return "eliminate_concat::concat"; }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
return op.compute_shape(inputs);
}
migraph::argument
compute(migraph::context& ctx, const migraph::shape& output_shape, const std::vector<migraph::argument>& args) const
{
return {output_shape};
}
};
struct concat_test_optimization
{
/// A unique name used to identify the concat optimization
std::string name() const
{
return "eliminate_concat::concat";
}
/// A unique name used to identify the allocate operator
std::string allocate() const
{
return "allocate";
}
/// Return the lowered concat operator
migraph::op::concat get_concat(const migraph::operation& op) const
{
return migraph::any_cast<concat>(op).op;
}
};
struct eliminate_concat_target
{
std::size_t align = 32;
std::string name() const { return "eliminate_target"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::eliminate_concat{concat_test_optimization{}}, migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
struct allocate
{
migraph::shape s{};
std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs}.has(0);
return s;
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>&) const
{
return {output_shape};
}
};
struct fred_op
{
std::string name() const { return "fred_op"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs}.has(1);
return inputs.at(0);
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>& args) const
{
return args.at(0);
}
};
void basic()
{
auto create_test_program = []() {
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,2,8,8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,3,8,8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,5,8,8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,10,8,8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto create_control_program = []() {
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,10,8,8}}});
auto l1 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,2,8,8}}, 0}, {a1});
auto p1 = p.add_instruction(fred_op{}, l1);
auto l2 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,3,8,8}}, 128}, {a1});
auto p2 = p.add_instruction(fred_op{}, l2);
auto l3 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,5,8,8}}, 320}, {a1});
auto p3 = p.add_instruction(fred_op{}, l3);
auto i1 = p.add_instruction(migraph::op::identity{}, {a1, p1, p2, p3});
return p;
};
auto p1 = create_test_program();
auto p2 = create_control_program();
p1.compile(eliminate_concat_target{});
EXPECT(p1 == p2);
}
void wont_work()
{
auto create_test_program = []() {
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,2,8,8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,3,8,8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,5,8,8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,10,8,8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto create_control_program = []() {
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,2,8,8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,3,8,8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,5,8,8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,10,8,8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto p1 = create_test_program();
auto p2 = create_control_program();
p1.compile(eliminate_concat_target{});
EXPECT(p1 == p2);
}
int main()
{
setenv("MIGRAPH_DISABLE_MEMORY_COLORING", "1", 1);
basic();
wont_work();
}
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/operation.hpp>
#include <migraph/operators.hpp>
namespace migraph {
struct program;
#ifdef DOXYGEN
/// An interface for applying an optimization for the concat instruction
struct concat_optimization
{
/// A unique name used to identify the concat optimization
std::string name() const;
/// A unique name used to identify the allocate operator
std::string allocate() const;
/// Return the lowered concat operator
op::concat get_concat(const operation& op) const;
};
#else
<%
interface('concat_optimization',
virtual('name', returns='std::string', const=True),
virtual('allocate', returns='std::string', const=True),
virtual('get_concat', returns='op::concat', op='const operation&', const=True)
)
%>
#endif
} // namespace migraph
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment