Commit c174704f authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent a2092da6
...@@ -19,50 +19,48 @@ void eliminate_concat::apply(program& p) const ...@@ -19,50 +19,48 @@ void eliminate_concat::apply(program& p) const
return arg->name() == "@literal"; return arg->name() == "@literal";
})) }))
continue; continue;
// We can only do this optimization when concat axis is either the leftmost // We can only do this optimization when concat axis is either the leftmost
// axis OR the sizes to the left of this axis are all equal to 1 // axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical // Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input // we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens(); auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator()); auto concat_op = concat_opt.get_concat(ins->get_operator());
if (concat_op.axis == 0 || if(concat_op.axis == 0 ||
std::all_of(lens.begin(), lens.begin()+concat_op.axis, std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; }))
[] (auto x) {
return x == 1;
}))
{ {
// Last input should be an allocation // Last input should be an allocation
auto last = ins->inputs().back(); auto last = ins->inputs().back();
if (last->name() != concat_opt.allocate()) continue; if(last->name() != concat_opt.allocate())
continue;
// Where are the allocations for the tensors to be concatenated? // Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations; std::vector<instruction_ref> allocations;
for (auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end()-1; ins2++) for(auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end() - 1; ins2++)
{ {
auto last2 = (*ins2)->inputs().back(); auto last2 = (*ins2)->inputs().back();
if (last2->name() == concat_opt.allocate()) if(last2->name() == concat_opt.allocate())
{ {
allocations.push_back(last2); allocations.push_back(last2);
} }
} }
// Need to sort the allocations, so that we know where to // Need to sort the allocations, so that we know where to
// insert the "super"-allocation // insert the "super"-allocation
std::sort(allocations.begin(), allocations.end(), [&] (instruction_ref x, instruction_ref y) { std::sort(
return std::distance(p.begin(), x) < std::distance(p.begin(), y); allocations.begin(), allocations.end(), [&](instruction_ref x, instruction_ref y) {
}); return std::distance(p.begin(), x) < std::distance(p.begin(), y);
});
// Move "super" allocation to the front // Move "super" allocation to the front
auto first = allocations.front(); auto first = allocations.front();
auto super = p.move_instruction(last, first); auto super = p.move_instruction(last, first);
std::size_t offset = 0; std::size_t offset = 0;
for (auto x : allocations) for(auto x : allocations)
{ {
migraph::op::load op{x->get_shape(), offset}; migraph::op::load op{x->get_shape(), offset};
p.replace_instruction(x, op, {super}); p.replace_instruction(x, op, {super});
offset += x->get_shape().elements(); offset += x->get_shape().elements();
} }
std::vector<instruction_ref> args = {super}; std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end()-1, std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
std::back_inserter(args));
p.replace_instruction(ins, migraph::op::identity{}, args); p.replace_instruction(ins, migraph::op::identity{}, args);
} }
} }
......
...@@ -11,7 +11,7 @@ struct program; ...@@ -11,7 +11,7 @@ struct program;
struct eliminate_concat struct eliminate_concat
{ {
concat_optimization concat_opt; concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; } std::string name() const { return "eliminate_concat"; }
void apply(program& p) const; void apply(program& p) const;
}; };
......
...@@ -620,10 +620,7 @@ struct unary ...@@ -620,10 +620,7 @@ struct unary
struct identity struct identity
{ {
std::string name() const { return "identity"; } std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
{
return inputs.at(0);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
......
...@@ -6,35 +6,27 @@ ...@@ -6,35 +6,27 @@
struct concat struct concat
{ {
concat(std::size_t axis) concat(std::size_t axis) { op.axis = axis; }
{
op.axis = axis;
}
migraph::op::concat op; migraph::op::concat op;
std::string name() const { return "eliminate_concat::concat"; } std::string name() const { return "eliminate_concat::concat"; }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{ {
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
migraph::argument migraph::argument compute(migraph::context& ctx,
compute(migraph::context& ctx, const migraph::shape& output_shape, const std::vector<migraph::argument>& args) const const migraph::shape& output_shape,
const std::vector<migraph::argument>& args) const
{ {
return {output_shape}; return {output_shape};
} }
}; };
struct concat_test_optimization struct concat_test_optimization
{ {
/// A unique name used to identify the concat optimization /// A unique name used to identify the concat optimization
std::string name() const std::string name() const { return "eliminate_concat::concat"; }
{
return "eliminate_concat::concat";
}
/// A unique name used to identify the allocate operator /// A unique name used to identify the allocate operator
std::string allocate() const std::string allocate() const { return "allocate"; }
{
return "allocate";
}
/// Return the lowered concat operator /// Return the lowered concat operator
migraph::op::concat get_concat(const migraph::operation& op) const migraph::op::concat get_concat(const migraph::operation& op) const
{ {
...@@ -48,7 +40,8 @@ struct eliminate_concat_target ...@@ -48,7 +40,8 @@ struct eliminate_concat_target
std::string name() const { return "eliminate_target"; } std::string name() const { return "eliminate_target"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraph::pass> get_passes(migraph::context&) const
{ {
return {migraph::eliminate_concat{concat_test_optimization{}}, migraph::dead_code_elimination{}}; return {migraph::eliminate_concat{concat_test_optimization{}},
migraph::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
...@@ -84,32 +77,39 @@ struct fred_op ...@@ -84,32 +77,39 @@ struct fred_op
{ {
return args.at(0); return args.at(0);
} }
}; };
void basic() void basic()
{ {
auto create_test_program = []() { auto create_test_program = []() {
migraph::program p; migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,2,8,8}}}); auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,3,8,8}}}); auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,5,8,8}}}); auto a3 =
auto p3 = p.add_instruction(fred_op{}, a3); p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,10,8,8}}}); auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4); auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto create_control_program = []() { auto create_control_program = []() {
migraph::program p; migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,10,8,8}}}); auto a1 =
auto l1 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,2,8,8}}, 0}, {a1}); p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
auto l1 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
auto p1 = p.add_instruction(fred_op{}, l1); auto p1 = p.add_instruction(fred_op{}, l1);
auto l2 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,3,8,8}}, 128}, {a1}); auto l2 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}, 128}, {a1});
auto p2 = p.add_instruction(fred_op{}, l2); auto p2 = p.add_instruction(fred_op{}, l2);
auto l3 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,5,8,8}}, 320}, {a1}); auto l3 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}, 320}, {a1});
auto p3 = p.add_instruction(fred_op{}, l3); auto p3 = p.add_instruction(fred_op{}, l3);
auto i1 = p.add_instruction(migraph::op::identity{}, {a1, p1, p2, p3}); auto i1 = p.add_instruction(migraph::op::identity{}, {a1, p1, p2, p3});
return p; return p;
...@@ -126,29 +126,37 @@ void wont_work() ...@@ -126,29 +126,37 @@ void wont_work()
{ {
auto create_test_program = []() { auto create_test_program = []() {
migraph::program p; migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,2,8,8}}}); auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,3,8,8}}}); auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,5,8,8}}}); auto a3 =
auto p3 = p.add_instruction(fred_op{}, a3); p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,10,8,8}}}); auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4); auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto create_control_program = []() { auto create_control_program = []() {
migraph::program p; migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,2,8,8}}}); auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,3,8,8}}}); auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,5,8,8}}}); auto a3 =
auto p3 = p.add_instruction(fred_op{}, a3); p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,10,8,8}}}); auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4); auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto p1 = create_test_program(); auto p1 = create_test_program();
......
...@@ -18,7 +18,7 @@ struct program; ...@@ -18,7 +18,7 @@ struct program;
#ifdef DOXYGEN #ifdef DOXYGEN
/// An interface for applying an optimization for the concat instruction /// An interface for applying an optimization for the concat instruction
struct concat_optimization struct concat_optimization
{ {
/// A unique name used to identify the concat optimization /// A unique name used to identify the concat optimization
std::string name() const; std::string name() const;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment