Commit 22cee7ff authored by Paul's avatar Paul
Browse files

Format

parent d0dbaf41
...@@ -12,10 +12,7 @@ struct common_dims ...@@ -12,10 +12,7 @@ struct common_dims
{ {
static common_dims compute(const std::vector<std::size_t>& dims1, static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2); const std::vector<std::size_t>& dims2);
bool empty() const bool empty() const { return dims.empty(); }
{
return dims.empty();
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1; std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2; std::vector<std::vector<std::size_t>> axes_map2;
......
...@@ -203,13 +203,13 @@ struct basic_matcher ...@@ -203,13 +203,13 @@ struct basic_matcher
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_basic_fun_matcher([=](matcher_context& ctx, return make_basic_fun_matcher(
instruction_ref ins) -> optional<instruction_ref> { [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result) if(result)
{ {
bool matches = bool matches = +fold(
+ fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...); [&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches) if(matches)
return result; return result;
} }
......
...@@ -54,7 +54,11 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins ...@@ -54,7 +54,11 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
*/ */
struct module struct module
{ {
using inserter = std::function<instruction_ref(module& m, instruction_ref ins, const operation& op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_args)>; using inserter = std::function<instruction_ref(module& m,
instruction_ref ins,
const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& module_args)>;
module(const std::string& name = ""); module(const std::string& name = "");
// move constructor // move constructor
......
...@@ -328,7 +328,9 @@ bool instruction::can_eval() const ...@@ -328,7 +328,9 @@ bool instruction::can_eval() const
} }
else if(is_context_free(op)) else if(is_context_free(op))
{ {
assert(std::none_of(this->inputs().begin(), this->inputs().end(), [&](instruction_ref arg) { return std::addressof(*arg) == this; })); assert(std::none_of(this->inputs().begin(), this->inputs().end(), [&](instruction_ref arg) {
return std::addressof(*arg) == this;
}));
return std::all_of( return std::all_of(
this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); }); this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
} }
......
...@@ -430,11 +430,13 @@ module::insert_instructions(instruction_ref ins, ...@@ -430,11 +430,13 @@ module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions, const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins) std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins), default_module_inserter()); return insert_generic_instructions(
*this, ins, instructions, std::move(map_ins), default_module_inserter());
} }
std::vector<instruction_ref> std::vector<instruction_ref>
module::insert_instructions(module::inserter insert, instruction_ref ins, module::insert_instructions(module::inserter insert,
instruction_ref ins,
const std::vector<instruction_ref>& instructions, const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins) std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
...@@ -446,7 +448,8 @@ module::insert_instructions(instruction_ref ins, ...@@ -446,7 +448,8 @@ module::insert_instructions(instruction_ref ins,
const_module_ref m, const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins) std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins), default_module_inserter()); return insert_generic_instructions(
*this, ins, iterator_for(*m), std::move(map_ins), default_module_inserter());
} }
std::vector<instruction_ref> std::vector<instruction_ref>
...@@ -456,7 +459,8 @@ module::insert_instructions(instruction_ref ins, ...@@ -456,7 +459,8 @@ module::insert_instructions(instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref> map_ins) std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
auto r = range(start, last); auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins), default_module_inserter()); return insert_generic_instructions(
*this, ins, iterator_for(r), std::move(map_ins), default_module_inserter());
} }
instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); }
......
...@@ -917,10 +917,11 @@ struct find_broadcast_reshaper ...@@ -917,10 +917,11 @@ struct find_broadcast_reshaper
struct find_poinwise_reduce_reshape struct find_poinwise_reduce_reshape
{ {
template<class... Ms> template <class... Ms>
static auto match_reshaper(Ms... ms) static auto match_reshaper(Ms... ms)
{ {
return match::name({"reshape", "squeeze", "unsqueeze"})(match::arg(0)(match::skip(match::name("contiguous"))(ms...))); return match::name({"reshape", "squeeze", "unsqueeze"})(
match::arg(0)(match::skip(match::name("contiguous"))(ms...)));
} }
auto matcher() const auto matcher() const
{ {
...@@ -935,10 +936,7 @@ struct find_poinwise_reduce_reshape ...@@ -935,10 +936,7 @@ struct find_poinwise_reduce_reshape
return contains({"broadcast", "multibroadcast"}, op.name()); return contains({"broadcast", "multibroadcast"}, op.name());
} }
static bool is_broadcast(instruction_ref ins) static bool is_broadcast(instruction_ref ins) { return is_broadcast(ins->get_operator()); }
{
return is_broadcast(ins->get_operator());
}
static bool is_pointwise(instruction_ref ins) static bool is_pointwise(instruction_ref ins)
{ {
...@@ -946,10 +944,7 @@ struct find_poinwise_reduce_reshape ...@@ -946,10 +944,7 @@ struct find_poinwise_reduce_reshape
return a.get("pointwise", false); return a.get("pointwise", false);
} }
static bool is_reduce(instruction_ref ins) static bool is_reduce(instruction_ref ins) { return is_reduce(ins->get_operator()); }
{
return is_reduce(ins->get_operator());
}
static bool is_reduce(const operation& op) static bool is_reduce(const operation& op)
{ {
...@@ -963,23 +958,25 @@ struct find_poinwise_reduce_reshape ...@@ -963,23 +958,25 @@ struct find_poinwise_reduce_reshape
return a.get("pointwise", false) or a.get("reduce", false); return a.get("pointwise", false) or a.get("reduce", false);
} }
static std::vector<instruction_ref> topo_sort(instruction_ref entry, const std::unordered_set<instruction_ref>& inss, std::unordered_set<instruction_ref>& aux) static std::vector<instruction_ref> topo_sort(instruction_ref entry,
const std::unordered_set<instruction_ref>& inss,
std::unordered_set<instruction_ref>& aux)
{ {
std::vector<instruction_ref> instructions; std::vector<instruction_ref> instructions;
bool has_entry = contains(inss, entry); bool has_entry = contains(inss, entry);
fix([&](auto self, instruction_ref ins) { fix([&](auto self, instruction_ref ins) {
if (ins != entry or has_entry) if(ins != entry or has_entry)
instructions.push_back(ins); instructions.push_back(ins);
for(auto input:ins->inputs()) for(auto input : ins->inputs())
{ {
if(not contains(inss, input)) if(not contains(inss, input))
aux.insert(input); aux.insert(input);
} }
for(auto output : ins->outputs()) for(auto output : ins->outputs())
{ {
if (contains(instructions, output)) if(contains(instructions, output))
continue; continue;
if (not contains(inss, output)) if(not contains(inss, output))
continue; continue;
self(output); self(output);
} }
...@@ -988,23 +985,24 @@ struct find_poinwise_reduce_reshape ...@@ -988,23 +985,24 @@ struct find_poinwise_reduce_reshape
return instructions; return instructions;
} }
static std::vector<instruction_ref> topo_sort(const std::unordered_set<instruction_ref>& inss, std::unordered_set<instruction_ref>& aux) static std::vector<instruction_ref> topo_sort(const std::unordered_set<instruction_ref>& inss,
std::unordered_set<instruction_ref>& aux)
{ {
std::vector<instruction_ref> instructions; std::vector<instruction_ref> instructions;
std::unordered_set<instruction_ref> visited; std::unordered_set<instruction_ref> visited;
for(auto ins:inss) for(auto ins : inss)
{ {
fix([&](auto self, instruction_ref child) { fix([&](auto self, instruction_ref child) {
if (contains(visited, child)) if(contains(visited, child))
return; return;
for(auto output:child->outputs()) for(auto output : child->outputs())
{ {
if (not contains(inss, output)) if(not contains(inss, output))
continue; continue;
self(output); self(output);
} }
visited.insert(child); visited.insert(child);
for(auto input:child->inputs()) for(auto input : child->inputs())
{ {
if(not contains(inss, input)) if(not contains(inss, input))
aux.insert(input); aux.insert(input);
...@@ -1029,7 +1027,7 @@ struct find_poinwise_reduce_reshape ...@@ -1029,7 +1027,7 @@ struct find_poinwise_reduce_reshape
auto dims2 = reshape_ins->get_shape().lens(); auto dims2 = reshape_ins->get_shape().lens();
auto cd = common_dims::compute(dims1, dims2); auto cd = common_dims::compute(dims1, dims2);
if (cd.empty()) if(cd.empty())
return; return;
// m.debug_print(); // m.debug_print();
...@@ -1077,15 +1075,17 @@ struct find_poinwise_reduce_reshape ...@@ -1077,15 +1075,17 @@ struct find_poinwise_reduce_reshape
{ {
if(not std::all_of( if(not std::all_of(
output->inputs().begin(), output->inputs().end(), [&](auto input) { output->inputs().begin(), output->inputs().end(), [&](auto input) {
return input->can_eval() or reshape_ins == input or contains(output_inss, input);// or dom.strictly_dominate(reshape_ins, input); return input->can_eval() or reshape_ins == input or
contains(output_inss,
input); // or dom.strictly_dominate(reshape_ins, input);
})) }))
continue; continue;
if(not is_pointwise_or_reduce(output) and not is_broadcast(output)) if(not is_pointwise_or_reduce(output) and not is_broadcast(output))
continue; continue;
if (is_reduce(output)) if(is_reduce(output))
{ {
auto op_axes = output->get_operator().to_value()["axes"].to_vector<int64_t>(); auto op_axes = output->get_operator().to_value()["axes"].to_vector<int64_t>();
if (axes.empty()) if(axes.empty())
axes = op_axes; axes = op_axes;
if(axes != op_axes) if(axes != op_axes)
return; return;
...@@ -1096,12 +1096,13 @@ struct find_poinwise_reduce_reshape ...@@ -1096,12 +1096,13 @@ struct find_poinwise_reduce_reshape
})(reshape_ins); })(reshape_ins);
std::vector<int64_t> common_axes; std::vector<int64_t> common_axes;
for(auto axis:axes) for(auto axis : axes)
{ {
common_axes.insert(common_axes.end(), cd.axes_map2[axis].begin(), cd.axes_map2[axis].end()); common_axes.insert(
common_axes.end(), cd.axes_map2[axis].begin(), cd.axes_map2[axis].end());
} }
auto common_rdims = cd.dims; auto common_rdims = cd.dims;
for(auto axis:common_axes) for(auto axis : common_axes)
{ {
common_rdims[axis] = 1; common_rdims[axis] = 1;
} }
...@@ -1120,19 +1121,24 @@ struct find_poinwise_reduce_reshape ...@@ -1120,19 +1121,24 @@ struct find_poinwise_reduce_reshape
auto insert_reshape = [&](instruction_ref input) { auto insert_reshape = [&](instruction_ref input) {
auto use_rdims = input->get_shape().elements() < nelements; auto use_rdims = input->get_shape().elements() < nelements;
auto c = m.insert_instruction(last, make_op("contiguous"), input); auto c = m.insert_instruction(last, make_op("contiguous"), input);
return m.insert_instruction(last, make_op("reshape", {{"dims", use_rdims ? common_rdims : cd.dims}}), c); return m.insert_instruction(
last, make_op("reshape", {{"dims", use_rdims ? common_rdims : cd.dims}}), c);
}; };
std::unordered_map<instruction_ref, instruction_ref> map_ins; std::unordered_map<instruction_ref, instruction_ref> map_ins;
// map_ins[entry] = insert_reshape(entry); // map_ins[entry] = insert_reshape(entry);
for(auto i:aux) for(auto i : aux)
{ {
map_ins[i] = insert_reshape(i); map_ins[i] = insert_reshape(i);
} }
auto inserter = [&](module& mm, instruction_ref i, operation op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_args) { auto inserter = [&](module& mm,
if (is_reduce(op)) instruction_ref i,
operation op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& module_args) {
if(is_reduce(op))
op.from_value({{"axes", common_axes}}); op.from_value({{"axes", common_axes}});
if (is_broadcast(op)) if(is_broadcast(op))
op.from_value({{"out_lens", cd.dims}}); op.from_value({{"out_lens", cd.dims}});
// std::cout << op << std::endl; // std::cout << op << std::endl;
// m.debug_print(args); // m.debug_print(args);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment