Commit 22cee7ff authored by Paul
Browse files

Format

parent d0dbaf41
......@@ -12,10 +12,7 @@ struct common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);
// Returns true when `dims` holds no elements.
bool empty() const
{
return dims.empty();
}
// Returns true when `dims` holds no elements.
bool empty() const { return dims.empty(); }
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
......
......@@ -203,18 +203,18 @@ struct basic_matcher
{
// Copy m because we can't capture `this` by value
auto mm = m;
return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result)
{
bool matches =
+ fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches)
return result;
}
return nullopt;
});
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result)
{
bool matches = +fold(
[&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches)
return result;
}
return nullopt;
});
}
// Returns the result of bind_match for the wrapped matcher `m` and `name`.
// `name` is taken by value and moved into the call, so no extra copy is made.
auto bind(std::string name) const
{
    return bind_match(m, std::move(name));
}
......
......@@ -54,7 +54,11 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
*/
struct module
{
using inserter = std::function<instruction_ref(module& m, instruction_ref ins, const operation& op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_args)>;
using inserter = std::function<instruction_ref(module& m,
instruction_ref ins,
const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& module_args)>;
module(const std::string& name = "");
// move constructor
......
......@@ -328,7 +328,9 @@ bool instruction::can_eval() const
}
else if(is_context_free(op))
{
assert(std::none_of(this->inputs().begin(), this->inputs().end(), [&](instruction_ref arg) { return std::addressof(*arg) == this; }));
assert(std::none_of(this->inputs().begin(), this->inputs().end(), [&](instruction_ref arg) {
return std::addressof(*arg) == this;
}));
return std::all_of(
this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
}
......
......@@ -430,11 +430,13 @@ module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins), default_module_inserter());
return insert_generic_instructions(
*this, ins, instructions, std::move(map_ins), default_module_inserter());
}
std::vector<instruction_ref>
module::insert_instructions(module::inserter insert, instruction_ref ins,
module::insert_instructions(module::inserter insert,
instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
......@@ -446,7 +448,8 @@ module::insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins), default_module_inserter());
return insert_generic_instructions(
*this, ins, iterator_for(*m), std::move(map_ins), default_module_inserter());
}
std::vector<instruction_ref>
......@@ -456,7 +459,8 @@ module::insert_instructions(instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins), default_module_inserter());
return insert_generic_instructions(
*this, ins, iterator_for(r), std::move(map_ins), default_module_inserter());
}
// Adds literal `l` by delegating to insert_literal with begin() as the
// insertion position, and returns the instruction_ref produced by that call.
// `l` is moved into the call.
instruction_ref module::add_literal(literal l)
{
    return this->insert_literal(this->begin(), std::move(l));
}
......
......@@ -917,10 +917,11 @@ struct find_broadcast_reshaper
struct find_poinwise_reduce_reshape
{
// Builds a matcher for an instruction named "reshape", "squeeze" or
// "unsqueeze" whose argument 0 is matched by
// match::skip(match::name("contiguous"))(ms...).
// NOTE(review): presumably match::skip steps over intervening "contiguous"
// instructions before applying `ms...` — confirm against the matcher library.
template<class... Ms>
template <class... Ms>
static auto match_reshaper(Ms... ms)
{
return match::name({"reshape", "squeeze", "unsqueeze"})(match::arg(0)(match::skip(match::name("contiguous"))(ms...)));
return match::name({"reshape", "squeeze", "unsqueeze"})(
match::arg(0)(match::skip(match::name("contiguous"))(ms...)));
}
auto matcher() const
{
......@@ -935,10 +936,7 @@ struct find_poinwise_reduce_reshape
return contains({"broadcast", "multibroadcast"}, op.name());
}
// Instruction overload: forwards the instruction's operator
// (ins->get_operator()) to the is_broadcast overload that takes an operator.
static bool is_broadcast(instruction_ref ins)
{
return is_broadcast(ins->get_operator());
}
// Instruction overload: forwards the instruction's operator
// (ins->get_operator()) to the is_broadcast overload that takes an operator.
static bool is_broadcast(instruction_ref ins) { return is_broadcast(ins->get_operator()); }
static bool is_pointwise(instruction_ref ins)
{
......@@ -946,10 +944,7 @@ struct find_poinwise_reduce_reshape
return a.get("pointwise", false);
}
// Instruction overload: forwards the instruction's operator
// (ins->get_operator()) to is_reduce(const operation&).
static bool is_reduce(instruction_ref ins)
{
return is_reduce(ins->get_operator());
}
// Instruction overload: forwards the instruction's operator
// (ins->get_operator()) to is_reduce(const operation&).
static bool is_reduce(instruction_ref ins) { return is_reduce(ins->get_operator()); }
static bool is_reduce(const operation& op)
{
......@@ -963,23 +958,25 @@ struct find_poinwise_reduce_reshape
return a.get("pointwise", false) or a.get("reduce", false);
}
static std::vector<instruction_ref> topo_sort(instruction_ref entry, const std::unordered_set<instruction_ref>& inss, std::unordered_set<instruction_ref>& aux)
static std::vector<instruction_ref> topo_sort(instruction_ref entry,
const std::unordered_set<instruction_ref>& inss,
std::unordered_set<instruction_ref>& aux)
{
std::vector<instruction_ref> instructions;
bool has_entry = contains(inss, entry);
fix([&](auto self, instruction_ref ins) {
if (ins != entry or has_entry)
if(ins != entry or has_entry)
instructions.push_back(ins);
for(auto input:ins->inputs())
for(auto input : ins->inputs())
{
if(not contains(inss, input))
aux.insert(input);
}
for(auto output : ins->outputs())
{
if (contains(instructions, output))
if(contains(instructions, output))
continue;
if (not contains(inss, output))
if(not contains(inss, output))
continue;
self(output);
}
......@@ -988,23 +985,24 @@ struct find_poinwise_reduce_reshape
return instructions;
}
static std::vector<instruction_ref> topo_sort(const std::unordered_set<instruction_ref>& inss, std::unordered_set<instruction_ref>& aux)
static std::vector<instruction_ref> topo_sort(const std::unordered_set<instruction_ref>& inss,
std::unordered_set<instruction_ref>& aux)
{
std::vector<instruction_ref> instructions;
std::unordered_set<instruction_ref> visited;
for(auto ins:inss)
for(auto ins : inss)
{
fix([&](auto self, instruction_ref child) {
if (contains(visited, child))
if(contains(visited, child))
return;
for(auto output:child->outputs())
for(auto output : child->outputs())
{
if (not contains(inss, output))
if(not contains(inss, output))
continue;
self(output);
}
visited.insert(child);
for(auto input:child->inputs())
for(auto input : child->inputs())
{
if(not contains(inss, input))
aux.insert(input);
......@@ -1025,11 +1023,11 @@ struct find_poinwise_reduce_reshape
auto reshape_ins = r.instructions["reshape"];
auto nelements = x_ins->get_shape().elements();
auto dims1 = x_ins->get_shape().lens();
auto dims2 = reshape_ins->get_shape().lens();
auto dims1 = x_ins->get_shape().lens();
auto dims2 = reshape_ins->get_shape().lens();
auto cd = common_dims::compute(dims1, dims2);
if (cd.empty())
if(cd.empty())
return;
// m.debug_print();
......@@ -1064,28 +1062,30 @@ struct find_poinwise_reduce_reshape
// Collect from output
fix([&](auto self, instruction_ref out) {
// if(contains(inss, out))
// return;
// return;
// std::cout << "Visit: ";
// m.debug_print(out);
// m.debug_print(out->inputs());
auto outputs = out->outputs();
std::sort(outputs.begin(), outputs.end(), by(std::less<>{}, [&](instruction_ref i) {
return std::distance(reshape_ins, i);
}));
return std::distance(reshape_ins, i);
}));
// m.debug_print(outputs);
for(auto output : outputs)
{
if(not std::all_of(
output->inputs().begin(), output->inputs().end(), [&](auto input) {
return input->can_eval() or reshape_ins == input or contains(output_inss, input);// or dom.strictly_dominate(reshape_ins, input);
return input->can_eval() or reshape_ins == input or
contains(output_inss,
input); // or dom.strictly_dominate(reshape_ins, input);
}))
continue;
if(not is_pointwise_or_reduce(output) and not is_broadcast(output))
continue;
if (is_reduce(output))
if(is_reduce(output))
{
auto op_axes = output->get_operator().to_value()["axes"].to_vector<int64_t>();
if (axes.empty())
if(axes.empty())
axes = op_axes;
if(axes != op_axes)
return;
......@@ -1096,18 +1096,19 @@ struct find_poinwise_reduce_reshape
})(reshape_ins);
std::vector<int64_t> common_axes;
for(auto axis:axes)
for(auto axis : axes)
{
common_axes.insert(common_axes.end(), cd.axes_map2[axis].begin(), cd.axes_map2[axis].end());
common_axes.insert(
common_axes.end(), cd.axes_map2[axis].begin(), cd.axes_map2[axis].end());
}
auto common_rdims = cd.dims;
for(auto axis:common_axes)
for(auto axis : common_axes)
{
common_rdims[axis] = 1;
}
// Topological sort
std::unordered_set<instruction_ref> aux;
auto input_instructions = topo_sort(input_inss, aux);
auto input_instructions = topo_sort(input_inss, aux);
auto output_instructions = topo_sort(output_inss, aux);
// std::cout << "output_inss:\n";
// m.debug_print({output_inss.begin(), output_inss.end()});
......@@ -1116,23 +1117,28 @@ struct find_poinwise_reduce_reshape
// std::cout << "aux:\n";
// m.debug_print({aux.begin(), aux.end()});
auto last = output_instructions.back();
auto last = output_instructions.back();
auto insert_reshape = [&](instruction_ref input) {
auto use_rdims = input->get_shape().elements() < nelements;
auto c = m.insert_instruction(last, make_op("contiguous"), input);
return m.insert_instruction(last, make_op("reshape", {{"dims", use_rdims ? common_rdims : cd.dims}}), c);
auto c = m.insert_instruction(last, make_op("contiguous"), input);
return m.insert_instruction(
last, make_op("reshape", {{"dims", use_rdims ? common_rdims : cd.dims}}), c);
};
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// map_ins[entry] = insert_reshape(entry);
for(auto i:aux)
for(auto i : aux)
{
map_ins[i] = insert_reshape(i);
}
auto inserter = [&](module& mm, instruction_ref i, operation op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_args) {
if (is_reduce(op))
auto inserter = [&](module& mm,
instruction_ref i,
operation op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& module_args) {
if(is_reduce(op))
op.from_value({{"axes", common_axes}});
if (is_broadcast(op))
if(is_broadcast(op))
op.from_value({{"out_lens", cd.dims}});
// std::cout << op << std::endl;
// m.debug_print(args);
......@@ -1141,7 +1147,7 @@ struct find_poinwise_reduce_reshape
auto new_x_ins = m.insert_instructions(inserter, last, input_instructions, map_ins).front();
map_ins[reshape_ins] = new_x_ins;
auto new_last = m.insert_instructions(inserter, last, output_instructions, map_ins).front();
auto new_c = m.insert_instruction(last, make_op("contiguous"), new_last);
auto new_c = m.insert_instruction(last, make_op("contiguous"), new_last);
auto new_reshape = m.insert_instruction(last, make_op("reshape", {{"dims", dims2}}), new_c);
m.debug_print();
m.debug_print(last);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment