Commit c174704f authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent a2092da6
......@@ -19,50 +19,48 @@ void eliminate_concat::apply(program& p) const
return arg->name() == "@literal";
}))
continue;
// We can only do this optimization when concat axis is either the leftmost
// We can only do this optimization when concat axis is either the leftmost
// axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens();
auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator());
if (concat_op.axis == 0 ||
std::all_of(lens.begin(), lens.begin()+concat_op.axis,
[] (auto x) {
return x == 1;
}))
if(concat_op.axis == 0 ||
std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; }))
{
// Last input should be an allocation
auto last = ins->inputs().back();
if (last->name() != concat_opt.allocate()) continue;
if(last->name() != concat_opt.allocate())
continue;
// Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations;
for (auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end()-1; ins2++)
for(auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end() - 1; ins2++)
{
auto last2 = (*ins2)->inputs().back();
if (last2->name() == concat_opt.allocate())
if(last2->name() == concat_opt.allocate())
{
allocations.push_back(last2);
}
}
// Need to sort the allocations, so that we know where to
// Need to sort the allocations, so that we know where to
// insert the "super"-allocation
std::sort(allocations.begin(), allocations.end(), [&] (instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y);
});
std::sort(
allocations.begin(), allocations.end(), [&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y);
});
// Move "super" allocation to the front
auto first = allocations.front();
auto super = p.move_instruction(last, first);
auto first = allocations.front();
auto super = p.move_instruction(last, first);
std::size_t offset = 0;
for (auto x : allocations)
for(auto x : allocations)
{
migraph::op::load op{x->get_shape(), offset};
p.replace_instruction(x, op, {super});
offset += x->get_shape().elements();
}
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end()-1,
std::back_inserter(args));
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraph::op::identity{}, args);
}
}
......
......@@ -11,7 +11,7 @@ struct program;
struct eliminate_concat
{
concat_optimization concat_opt;
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(program& p) const;
};
......
......@@ -620,10 +620,7 @@ struct unary
struct identity
{
std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.at(0);
}
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
......
......@@ -6,35 +6,27 @@
struct concat
{
concat(std::size_t axis)
{
op.axis = axis;
}
concat(std::size_t axis) { op.axis = axis; }
migraph::op::concat op;
std::string name() const { return "eliminate_concat::concat"; }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
return op.compute_shape(inputs);
}
migraph::argument
compute(migraph::context& ctx, const migraph::shape& output_shape, const std::vector<migraph::argument>& args) const
migraph::argument compute(migraph::context& ctx,
const migraph::shape& output_shape,
const std::vector<migraph::argument>& args) const
{
return {output_shape};
}
};
struct concat_test_optimization
struct concat_test_optimization
{
/// A unique name used to identify the concat optimization
std::string name() const
{
return "eliminate_concat::concat";
}
std::string name() const { return "eliminate_concat::concat"; }
/// A unique name used to identify the allocate operator
std::string allocate() const
{
return "allocate";
}
std::string allocate() const { return "allocate"; }
/// Return the lowered concat operator
migraph::op::concat get_concat(const migraph::operation& op) const
{
......@@ -48,7 +40,8 @@ struct eliminate_concat_target
std::string name() const { return "eliminate_target"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::eliminate_concat{concat_test_optimization{}}, migraph::dead_code_elimination{}};
return {migraph::eliminate_concat{concat_test_optimization{}},
migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
......@@ -84,32 +77,39 @@ struct fred_op
{
return args.at(0);
}
};
void basic()
{
auto create_test_program = []() {
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,2,8,8}}});
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,3,8,8}}});
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,5,8,8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,10,8,8}}});
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
return p;
};
auto create_control_program = []() {
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1,10,8,8}}});
auto l1 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,2,8,8}}, 0}, {a1});
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
auto l1 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
auto p1 = p.add_instruction(fred_op{}, l1);
auto l2 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,3,8,8}}, 128}, {a1});
auto l2 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}, 128}, {a1});
auto p2 = p.add_instruction(fred_op{}, l2);
auto l3 = p.add_instruction(migraph::op::load{migraph::shape{migraph::shape::float_type, {1,5,8,8}}, 320}, {a1});
auto l3 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}, 320}, {a1});
auto p3 = p.add_instruction(fred_op{}, l3);
auto i1 = p.add_instruction(migraph::op::identity{}, {a1, p1, p2, p3});
return p;
......@@ -126,29 +126,37 @@ void wont_work()
{
auto create_test_program = []() {
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,2,8,8}}});
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,3,8,8}}});
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,5,8,8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,10,8,8}}});
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
return p;
};
auto create_control_program = []() {
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,2,8,8}}});
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,3,8,8}}});
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,5,8,8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2,10,8,8}}});
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
auto p4 = p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
return p;
};
auto p1 = create_test_program();
......
......@@ -18,7 +18,7 @@ struct program;
#ifdef DOXYGEN
/// An interface for applying an optimization for the concat instruction
struct concat_optimization
struct concat_optimization
{
/// A unique name used to identify the concat optimization
std::string name() const;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment