Commit 22cee7ff authored by Paul
Browse files

Format

parent d0dbaf41
...@@ -12,10 +12,7 @@ struct common_dims ...@@ -12,10 +12,7 @@ struct common_dims
{ {
static common_dims compute(const std::vector<std::size_t>& dims1, static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2); const std::vector<std::size_t>& dims2);
bool empty() const bool empty() const { return dims.empty(); }
{
return dims.empty();
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1; std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2; std::vector<std::vector<std::size_t>> axes_map2;
......
...@@ -203,18 +203,18 @@ struct basic_matcher ...@@ -203,18 +203,18 @@ struct basic_matcher
{ {
// Copy m because we can't capture `this` by value // Copy m because we can't capture `this` by value
auto mm = m; auto mm = m;
return make_basic_fun_matcher([=](matcher_context& ctx, return make_basic_fun_matcher(
instruction_ref ins) -> optional<instruction_ref> { [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result) if(result)
{ {
bool matches = bool matches = +fold(
+ fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...); [&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches) if(matches)
return result; return result;
} }
return nullopt; return nullopt;
}); });
} }
auto bind(std::string name) const { return bind_match(m, std::move(name)); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
......
...@@ -54,7 +54,11 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins ...@@ -54,7 +54,11 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
*/ */
struct module struct module
{ {
using inserter = std::function<instruction_ref(module& m, instruction_ref ins, const operation& op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_args)>; using inserter = std::function<instruction_ref(module& m,
instruction_ref ins,
const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& module_args)>;
module(const std::string& name = ""); module(const std::string& name = "");
// move constructor // move constructor
......
...@@ -328,7 +328,9 @@ bool instruction::can_eval() const ...@@ -328,7 +328,9 @@ bool instruction::can_eval() const
} }
else if(is_context_free(op)) else if(is_context_free(op))
{ {
assert(std::none_of(this->inputs().begin(), this->inputs().end(), [&](instruction_ref arg) { return std::addressof(*arg) == this; })); assert(std::none_of(this->inputs().begin(), this->inputs().end(), [&](instruction_ref arg) {
return std::addressof(*arg) == this;
}));
return std::all_of( return std::all_of(
this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); }); this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
} }
......
...@@ -430,11 +430,13 @@ module::insert_instructions(instruction_ref ins, ...@@ -430,11 +430,13 @@ module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions, const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins) std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins), default_module_inserter()); return insert_generic_instructions(
*this, ins, instructions, std::move(map_ins), default_module_inserter());
} }
std::vector<instruction_ref> std::vector<instruction_ref>
module::insert_instructions(module::inserter insert, instruction_ref ins, module::insert_instructions(module::inserter insert,
instruction_ref ins,
const std::vector<instruction_ref>& instructions, const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins) std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
...@@ -446,7 +448,8 @@ module::insert_instructions(instruction_ref ins, ...@@ -446,7 +448,8 @@ module::insert_instructions(instruction_ref ins,
const_module_ref m, const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins) std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins), default_module_inserter()); return insert_generic_instructions(
*this, ins, iterator_for(*m), std::move(map_ins), default_module_inserter());
} }
std::vector<instruction_ref> std::vector<instruction_ref>
...@@ -456,7 +459,8 @@ module::insert_instructions(instruction_ref ins, ...@@ -456,7 +459,8 @@ module::insert_instructions(instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref> map_ins) std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
auto r = range(start, last); auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins), default_module_inserter()); return insert_generic_instructions(
*this, ins, iterator_for(r), std::move(map_ins), default_module_inserter());
} }
instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); }
......
...@@ -917,10 +917,11 @@ struct find_broadcast_reshaper ...@@ -917,10 +917,11 @@ struct find_broadcast_reshaper
struct find_poinwise_reduce_reshape struct find_poinwise_reduce_reshape
{ {
template<class... Ms> template <class... Ms>
static auto match_reshaper(Ms... ms) static auto match_reshaper(Ms... ms)
{ {
return match::name({"reshape", "squeeze", "unsqueeze"})(match::arg(0)(match::skip(match::name("contiguous"))(ms...))); return match::name({"reshape", "squeeze", "unsqueeze"})(
match::arg(0)(match::skip(match::name("contiguous"))(ms...)));
} }
auto matcher() const auto matcher() const
{ {
...@@ -935,10 +936,7 @@ struct find_poinwise_reduce_reshape ...@@ -935,10 +936,7 @@ struct find_poinwise_reduce_reshape
return contains({"broadcast", "multibroadcast"}, op.name()); return contains({"broadcast", "multibroadcast"}, op.name());
} }
static bool is_broadcast(instruction_ref ins) static bool is_broadcast(instruction_ref ins) { return is_broadcast(ins->get_operator()); }
{
return is_broadcast(ins->get_operator());
}
static bool is_pointwise(instruction_ref ins) static bool is_pointwise(instruction_ref ins)
{ {
...@@ -946,10 +944,7 @@ struct find_poinwise_reduce_reshape ...@@ -946,10 +944,7 @@ struct find_poinwise_reduce_reshape
return a.get("pointwise", false); return a.get("pointwise", false);
} }
static bool is_reduce(instruction_ref ins) static bool is_reduce(instruction_ref ins) { return is_reduce(ins->get_operator()); }
{
return is_reduce(ins->get_operator());
}
static bool is_reduce(const operation& op) static bool is_reduce(const operation& op)
{ {
...@@ -963,23 +958,25 @@ struct find_poinwise_reduce_reshape ...@@ -963,23 +958,25 @@ struct find_poinwise_reduce_reshape
return a.get("pointwise", false) or a.get("reduce", false); return a.get("pointwise", false) or a.get("reduce", false);
} }
static std::vector<instruction_ref> topo_sort(instruction_ref entry, const std::unordered_set<instruction_ref>& inss, std::unordered_set<instruction_ref>& aux) static std::vector<instruction_ref> topo_sort(instruction_ref entry,
const std::unordered_set<instruction_ref>& inss,
std::unordered_set<instruction_ref>& aux)
{ {
std::vector<instruction_ref> instructions; std::vector<instruction_ref> instructions;
bool has_entry = contains(inss, entry); bool has_entry = contains(inss, entry);
fix([&](auto self, instruction_ref ins) { fix([&](auto self, instruction_ref ins) {
if (ins != entry or has_entry) if(ins != entry or has_entry)
instructions.push_back(ins); instructions.push_back(ins);
for(auto input:ins->inputs()) for(auto input : ins->inputs())
{ {
if(not contains(inss, input)) if(not contains(inss, input))
aux.insert(input); aux.insert(input);
} }
for(auto output : ins->outputs()) for(auto output : ins->outputs())
{ {
if (contains(instructions, output)) if(contains(instructions, output))
continue; continue;
if (not contains(inss, output)) if(not contains(inss, output))
continue; continue;
self(output); self(output);
} }
...@@ -988,23 +985,24 @@ struct find_poinwise_reduce_reshape ...@@ -988,23 +985,24 @@ struct find_poinwise_reduce_reshape
return instructions; return instructions;
} }
static std::vector<instruction_ref> topo_sort(const std::unordered_set<instruction_ref>& inss, std::unordered_set<instruction_ref>& aux) static std::vector<instruction_ref> topo_sort(const std::unordered_set<instruction_ref>& inss,
std::unordered_set<instruction_ref>& aux)
{ {
std::vector<instruction_ref> instructions; std::vector<instruction_ref> instructions;
std::unordered_set<instruction_ref> visited; std::unordered_set<instruction_ref> visited;
for(auto ins:inss) for(auto ins : inss)
{ {
fix([&](auto self, instruction_ref child) { fix([&](auto self, instruction_ref child) {
if (contains(visited, child)) if(contains(visited, child))
return; return;
for(auto output:child->outputs()) for(auto output : child->outputs())
{ {
if (not contains(inss, output)) if(not contains(inss, output))
continue; continue;
self(output); self(output);
} }
visited.insert(child); visited.insert(child);
for(auto input:child->inputs()) for(auto input : child->inputs())
{ {
if(not contains(inss, input)) if(not contains(inss, input))
aux.insert(input); aux.insert(input);
...@@ -1025,11 +1023,11 @@ struct find_poinwise_reduce_reshape ...@@ -1025,11 +1023,11 @@ struct find_poinwise_reduce_reshape
auto reshape_ins = r.instructions["reshape"]; auto reshape_ins = r.instructions["reshape"];
auto nelements = x_ins->get_shape().elements(); auto nelements = x_ins->get_shape().elements();
auto dims1 = x_ins->get_shape().lens(); auto dims1 = x_ins->get_shape().lens();
auto dims2 = reshape_ins->get_shape().lens(); auto dims2 = reshape_ins->get_shape().lens();
auto cd = common_dims::compute(dims1, dims2); auto cd = common_dims::compute(dims1, dims2);
if (cd.empty()) if(cd.empty())
return; return;
// m.debug_print(); // m.debug_print();
...@@ -1064,28 +1062,30 @@ struct find_poinwise_reduce_reshape ...@@ -1064,28 +1062,30 @@ struct find_poinwise_reduce_reshape
// Collect from output // Collect from output
fix([&](auto self, instruction_ref out) { fix([&](auto self, instruction_ref out) {
// if(contains(inss, out)) // if(contains(inss, out))
// return; // return;
// std::cout << "Visit: "; // std::cout << "Visit: ";
// m.debug_print(out); // m.debug_print(out);
// m.debug_print(out->inputs()); // m.debug_print(out->inputs());
auto outputs = out->outputs(); auto outputs = out->outputs();
std::sort(outputs.begin(), outputs.end(), by(std::less<>{}, [&](instruction_ref i) { std::sort(outputs.begin(), outputs.end(), by(std::less<>{}, [&](instruction_ref i) {
return std::distance(reshape_ins, i); return std::distance(reshape_ins, i);
})); }));
// m.debug_print(outputs); // m.debug_print(outputs);
for(auto output : outputs) for(auto output : outputs)
{ {
if(not std::all_of( if(not std::all_of(
output->inputs().begin(), output->inputs().end(), [&](auto input) { output->inputs().begin(), output->inputs().end(), [&](auto input) {
return input->can_eval() or reshape_ins == input or contains(output_inss, input);// or dom.strictly_dominate(reshape_ins, input); return input->can_eval() or reshape_ins == input or
contains(output_inss,
input); // or dom.strictly_dominate(reshape_ins, input);
})) }))
continue; continue;
if(not is_pointwise_or_reduce(output) and not is_broadcast(output)) if(not is_pointwise_or_reduce(output) and not is_broadcast(output))
continue; continue;
if (is_reduce(output)) if(is_reduce(output))
{ {
auto op_axes = output->get_operator().to_value()["axes"].to_vector<int64_t>(); auto op_axes = output->get_operator().to_value()["axes"].to_vector<int64_t>();
if (axes.empty()) if(axes.empty())
axes = op_axes; axes = op_axes;
if(axes != op_axes) if(axes != op_axes)
return; return;
...@@ -1096,18 +1096,19 @@ struct find_poinwise_reduce_reshape ...@@ -1096,18 +1096,19 @@ struct find_poinwise_reduce_reshape
})(reshape_ins); })(reshape_ins);
std::vector<int64_t> common_axes; std::vector<int64_t> common_axes;
for(auto axis:axes) for(auto axis : axes)
{ {
common_axes.insert(common_axes.end(), cd.axes_map2[axis].begin(), cd.axes_map2[axis].end()); common_axes.insert(
common_axes.end(), cd.axes_map2[axis].begin(), cd.axes_map2[axis].end());
} }
auto common_rdims = cd.dims; auto common_rdims = cd.dims;
for(auto axis:common_axes) for(auto axis : common_axes)
{ {
common_rdims[axis] = 1; common_rdims[axis] = 1;
} }
// Topological sort // Topological sort
std::unordered_set<instruction_ref> aux; std::unordered_set<instruction_ref> aux;
auto input_instructions = topo_sort(input_inss, aux); auto input_instructions = topo_sort(input_inss, aux);
auto output_instructions = topo_sort(output_inss, aux); auto output_instructions = topo_sort(output_inss, aux);
// std::cout << "output_inss:\n"; // std::cout << "output_inss:\n";
// m.debug_print({output_inss.begin(), output_inss.end()}); // m.debug_print({output_inss.begin(), output_inss.end()});
...@@ -1116,23 +1117,28 @@ struct find_poinwise_reduce_reshape ...@@ -1116,23 +1117,28 @@ struct find_poinwise_reduce_reshape
// std::cout << "aux:\n"; // std::cout << "aux:\n";
// m.debug_print({aux.begin(), aux.end()}); // m.debug_print({aux.begin(), aux.end()});
auto last = output_instructions.back(); auto last = output_instructions.back();
auto insert_reshape = [&](instruction_ref input) { auto insert_reshape = [&](instruction_ref input) {
auto use_rdims = input->get_shape().elements() < nelements; auto use_rdims = input->get_shape().elements() < nelements;
auto c = m.insert_instruction(last, make_op("contiguous"), input); auto c = m.insert_instruction(last, make_op("contiguous"), input);
return m.insert_instruction(last, make_op("reshape", {{"dims", use_rdims ? common_rdims : cd.dims}}), c); return m.insert_instruction(
last, make_op("reshape", {{"dims", use_rdims ? common_rdims : cd.dims}}), c);
}; };
std::unordered_map<instruction_ref, instruction_ref> map_ins; std::unordered_map<instruction_ref, instruction_ref> map_ins;
// map_ins[entry] = insert_reshape(entry); // map_ins[entry] = insert_reshape(entry);
for(auto i:aux) for(auto i : aux)
{ {
map_ins[i] = insert_reshape(i); map_ins[i] = insert_reshape(i);
} }
auto inserter = [&](module& mm, instruction_ref i, operation op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_args) { auto inserter = [&](module& mm,
if (is_reduce(op)) instruction_ref i,
operation op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& module_args) {
if(is_reduce(op))
op.from_value({{"axes", common_axes}}); op.from_value({{"axes", common_axes}});
if (is_broadcast(op)) if(is_broadcast(op))
op.from_value({{"out_lens", cd.dims}}); op.from_value({{"out_lens", cd.dims}});
// std::cout << op << std::endl; // std::cout << op << std::endl;
// m.debug_print(args); // m.debug_print(args);
...@@ -1141,7 +1147,7 @@ struct find_poinwise_reduce_reshape ...@@ -1141,7 +1147,7 @@ struct find_poinwise_reduce_reshape
auto new_x_ins = m.insert_instructions(inserter, last, input_instructions, map_ins).front(); auto new_x_ins = m.insert_instructions(inserter, last, input_instructions, map_ins).front();
map_ins[reshape_ins] = new_x_ins; map_ins[reshape_ins] = new_x_ins;
auto new_last = m.insert_instructions(inserter, last, output_instructions, map_ins).front(); auto new_last = m.insert_instructions(inserter, last, output_instructions, map_ins).front();
auto new_c = m.insert_instruction(last, make_op("contiguous"), new_last); auto new_c = m.insert_instruction(last, make_op("contiguous"), new_last);
auto new_reshape = m.insert_instruction(last, make_op("reshape", {{"dims", dims2}}), new_c); auto new_reshape = m.insert_instruction(last, make_op("reshape", {{"dims", dims2}}), new_c);
m.debug_print(); m.debug_print();
m.debug_print(last); m.debug_print(last);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment