Commit d0dbaf41 authored by Paul
Browse files

Save code

parent da78b0c0
......@@ -46,7 +46,15 @@ bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins
return false;
}
struct module_visitor
// Visitor handed to compute_dominator_generic by compute_dominator: the nodes
// are the module's instructions and the children of an instruction are its
// inputs.
struct module_input_visitor
{
    // Module being traversed; dereferenced by get_nodes() without a null check.
    module* mm;
    // The node collection to traverse is the module itself.
    module& get_nodes() const { return *mm; }
    // Children of `ins` are its input instructions. Declared const, like
    // get_nodes(), because it does not read or modify any visitor state.
    const std::vector<instruction_ref>& get_children(instruction_ref ins) const
    {
        return ins->inputs();
    }
};
struct module_output_visitor
{
module* mm;
module& get_nodes() const { return *mm; }
......@@ -93,7 +101,12 @@ dominator_info compute_dominator_generic(Visitor v)
dominator_info compute_dominator(module& m)
{
return compute_dominator_generic(module_visitor{&m});
return compute_dominator_generic(module_input_visitor{&m});
}
// Runs the generic dominator computation over `m` using a
// module_output_visitor instead of the input visitor used by
// compute_dominator.
dominator_info compute_post_dominator(module& m)
{
    module_output_visitor visitor{&m};
    return compute_dominator_generic(visitor);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -12,6 +12,10 @@ struct common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);
// True when `dims` holds no elements.
bool empty() const { return dims.empty(); }
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
......
......@@ -42,7 +42,7 @@ struct dominator_info
};
dominator_info compute_dominator(module& m);
// dominator_info compute_dominator_naive(const module& m);
dominator_info compute_post_dominator(module& m);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -198,8 +198,8 @@ struct basic_matcher
{
M m;
template <class... Ts>
auto operator()(Ts... ms) const
template <class... Ms>
auto operator()(Ms... ms) const
{
// Copy m because we can't capture `this` by value
auto mm = m;
......@@ -209,7 +209,7 @@ struct basic_matcher
if(result)
{
bool matches =
fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
+ fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches)
return result;
}
......
......@@ -54,6 +54,7 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
*/
struct module
{
// Callback type accepted by insert_instructions to create an instruction.
// It receives the destination module, the insertion position, the operation,
// the argument instructions and the module arguments, and returns the
// instruction it inserted.
using inserter = std::function<instruction_ref(module& m, instruction_ref ins, const operation& op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_args)>;
module(const std::string& name = "");
// move constructor
......@@ -137,6 +138,12 @@ struct module
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
// Overload of insert_instructions that takes a caller-supplied `insert`
// callback (see `inserter`) as its first parameter; the remaining parameters
// are the insertion position, the instructions to copy and an optional
// mapping from source instructions to already-inserted ones.
std::vector<instruction_ref>
insert_instructions(inserter insert,
instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const_module_ref m,
......
......@@ -328,6 +328,7 @@ bool instruction::can_eval() const
}
else if(is_context_free(op))
{
assert(std::none_of(this->inputs().begin(), this->inputs().end(), [&](instruction_ref arg) { return std::addressof(*arg) == this; }));
return std::all_of(
this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
}
......
......@@ -197,12 +197,13 @@ void module::assign(const module& m)
}
}
template <class Range>
template <class Range, class Inserter>
static std::vector<instruction_ref>
insert_generic_instructions(module& m,
instruction_ref ins,
Range&& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
std::unordered_map<instruction_ref, instruction_ref> map_ins,
Inserter insert)
{
assert(m.has_instruction(ins) or is_end(ins, m.end()));
std::vector<instruction_ref> mod_outputs;
......@@ -244,7 +245,7 @@ insert_generic_instructions(module& m,
break;
}
copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
copy_ins = insert(m, ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
......@@ -253,6 +254,13 @@ insert_generic_instructions(module& m,
return mod_outputs;
}
// Builds the inserter used by the insert_instructions overloads that do not
// take one: every argument after the module is forwarded unchanged to
// module::insert_instruction.
static auto default_module_inserter()
{
    auto forward_to_module = [](module& m, auto&&... xs) {
        return m.insert_instruction(static_cast<decltype(xs)&&>(xs)...);
    };
    return forward_to_module;
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
......@@ -422,7 +430,15 @@ module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins));
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins), default_module_inserter());
}
// Copies `instructions` into this module at position `ins`, handing the
// caller-supplied `insert` callback to insert_generic_instructions so it can
// create the copied instructions. `map_ins` seeds the mapping from source
// instructions to instructions that already exist in this module.
std::vector<instruction_ref>
module::insert_instructions(module::inserter insert,
                            instruction_ref ins,
                            const std::vector<instruction_ref>& instructions,
                            std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
    // `insert` is a by-value std::function that is not used again here, so
    // move it into the callee instead of copying it a second time.
    return insert_generic_instructions(
        *this, ins, instructions, std::move(map_ins), std::move(insert));
}
std::vector<instruction_ref>
......@@ -430,7 +446,7 @@ module::insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins));
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins), default_module_inserter());
}
std::vector<instruction_ref>
......@@ -440,7 +456,7 @@ module::insert_instructions(instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins), default_module_inserter());
}
// Inserts literal `l` at the position returned by begin() and returns the
// result of insert_literal.
instruction_ref module::add_literal(literal l)
{
    const auto first = begin();
    return insert_literal(first, std::move(l));
}
......
......@@ -37,6 +37,8 @@
#include <unordered_set>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/common_dims.hpp>
#include <migraphx/dom_info.hpp>
#include <map>
......@@ -915,16 +917,29 @@ struct find_broadcast_reshaper
struct find_poinwise_reduce_reshape
{
// Matches a reshape/squeeze/unsqueeze whose argument 0, after skipping
// contiguous instructions, satisfies the matchers `ms...`.
template <class... Ms>
static auto match_reshaper(Ms... ms)
{
    auto reshaper        = match::name({"reshape", "squeeze", "unsqueeze"});
    auto skip_contiguous = match::skip(match::name("contiguous"));
    return reshaper(match::arg(0)(skip_contiguous(ms...)));
}
auto matcher() const
{
auto reshaper = match::name({"reshape", "squeeze", "unsqueeze"});
auto skip_contiguous = match::skip(match::name("contiguous"));
auto pointwise_or_reduce = match::any_of(match::pointwise(), match::reduce());
auto reshape_pointwise_or_reduce =
reshaper(skip_contiguous(pointwise_or_reduce.bind("x"))).bind("reshape");
match_reshaper(match::pointwise().bind("x")).bind("reshape");
return pointwise_or_reduce(match::any_of[match::inputs()](reshape_pointwise_or_reduce));
}
// True when the operation is named "broadcast" or "multibroadcast".
static bool is_broadcast(const operation& op)
{
    const auto op_name = op.name();
    return op_name == "broadcast" or op_name == "multibroadcast";
}
// Same check applied to the operator of instruction `ins`.
static bool is_broadcast(instruction_ref ins)
{
    const auto& op = ins->get_operator();
    return is_broadcast(op);
}
static bool is_pointwise(instruction_ref ins)
{
auto a = ins->get_operator().attributes();
......@@ -933,7 +948,12 @@ struct find_poinwise_reduce_reshape
static bool is_reduce(instruction_ref ins)
{
auto a = ins->get_operator().attributes();
return is_reduce(ins->get_operator());
}
// True when the operation's attributes carry a truthy "reduce" entry; a
// missing entry counts as false.
static bool is_reduce(const operation& op)
{
    return op.attributes().get("reduce", false);
}
......@@ -943,27 +963,87 @@ struct find_poinwise_reduce_reshape
return a.get("pointwise", false) or a.get("reduce", false);
}
// Topologically orders the instructions of `inss` that are reachable from
// `entry` by following outputs: every instruction is placed before any of its
// outputs that are also in `inss`. `entry` itself is only part of the result
// when it is a member of `inss`. Inputs of visited instructions that are not
// in `inss` are added to `aux`.
//
// The order is a reverse DFS post-order. Emitting nodes in pre-order (when
// first visited) is not a valid topological order when paths reconverge:
// with A->C, A->B, B->C it would yield A, C, B.
static std::vector<instruction_ref> topo_sort(instruction_ref entry, const std::unordered_set<instruction_ref>& inss, std::unordered_set<instruction_ref>& aux)
{
    std::vector<instruction_ref> instructions;
    // Hash set instead of a linear scan of `instructions` for the visited check.
    std::unordered_set<instruction_ref> visited;
    const bool has_entry = contains(inss, entry);
    fix([&](auto self, instruction_ref ins) {
        if(not visited.insert(ins).second)
            return;
        for(auto input : ins->inputs())
        {
            if(not contains(inss, input))
                aux.insert(input);
        }
        // Finish every in-set consumer before emitting `ins`.
        for(auto output : ins->outputs())
        {
            if(contains(inss, output))
                self(output);
        }
        if(ins != entry or has_entry)
            instructions.push_back(ins);
    })(entry);
    // Post-order lists consumers first; reverse to put producers first.
    std::reverse(instructions.begin(), instructions.end());
    assert(instructions.size() == inss.size());
    return instructions;
}
static std::vector<instruction_ref> topo_sort(const std::unordered_set<instruction_ref>& inss, std::unordered_set<instruction_ref>& aux)
{
std::vector<instruction_ref> instructions;
std::unordered_set<instruction_ref> visited;
for(auto ins:inss)
{
fix([&](auto self, instruction_ref child) {
if (contains(visited, child))
return;
for(auto output:child->outputs())
{
if (not contains(inss, output))
continue;
self(output);
}
visited.insert(child);
for(auto input:child->inputs())
{
if(not contains(inss, input))
aux.insert(input);
}
instructions.push_back(child);
})(ins);
}
std::reverse(instructions.begin(), instructions.end());
assert(instructions.size() == inss.size());
return instructions;
}
void apply(module& m, const match::matcher_result& r) const
{
// std::cout << "find_poinwise_reduce_reshape" << std::endl;
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto nelements = x_ins->get_shape().elements();
auto dims1 = x_ins->get_shape().lens();
auto dims2 = reshape_ins->get_shape().lens();
std::vector<int64_t> axes;
if(x_ins->get_operator().attributes().get("reduce", false))
{
axes = x_ins->get_operator().to_value()["axes"].to_vector<int64_t>();
}
std::unordered_set<instruction_ref> inss;
instruction_ref entry;
auto cd = common_dims::compute(dims1, dims2);
if (cd.empty())
return;
// m.debug_print();
// m.debug_print(reshape_ins);
// m.debug_print(ins);
// Collect from inputs
std::unordered_set<instruction_ref> input_inss;
instruction_ref entry;
fix([&](auto self, instruction_ref i) {
inss.insert(i);
if(contains(input_inss, i))
return;
input_inss.insert(i);
entry = i;
auto pointwise_or_reduce = [&](instruction_ref input) {
auto pointwise_or_reduce = [](instruction_ref input) {
if(input->can_eval())
return false;
return is_pointwise(input);
......@@ -977,37 +1057,97 @@ struct find_poinwise_reduce_reshape
return;
self(*it);
})(x_ins);
std::vector<int64_t> axes;
auto dom = compute_post_dominator(m);
std::unordered_set<instruction_ref> output_inss;
// Collect from output
fix([&](auto self, instruction_ref out) {
for(auto output : out->outputs())
// if(contains(inss, out))
// return;
// std::cout << "Visit: ";
// m.debug_print(out);
// m.debug_print(out->inputs());
auto outputs = out->outputs();
std::sort(outputs.begin(), outputs.end(), by(std::less<>{}, [&](instruction_ref i) {
return std::distance(reshape_ins, i);
}));
// m.debug_print(outputs);
for(auto output : outputs)
{
if(not std::all_of(
output->inputs().begin(), output->inputs().end(), [&](auto input) {
return input->can_eval() or contains(inss, input);
return input->can_eval() or reshape_ins == input or contains(output_inss, input);// or dom.strictly_dominate(reshape_ins, input);
}))
continue;
if(not is_pointwise_or_reduce(ins))
if(not is_pointwise_or_reduce(output) and not is_broadcast(output))
continue;
inss.insert(output);
self(output);
}
})(x_ins);
std::vector<instruction_ref> instructions;
std::unordered_set<instruction_ref> aux;
// Topological sort
fix([&](auto self, instruction_ref i) {
instructions.push_back(i);
for(auto output : i->outputs())
{
if(not contains(inss, output))
if (is_reduce(output))
{
aux.insert(output);
continue;
auto op_axes = output->get_operator().to_value()["axes"].to_vector<int64_t>();
if (axes.empty())
axes = op_axes;
if(axes != op_axes)
return;
}
output_inss.insert(output);
self(output);
}
})(entry);
})(reshape_ins);
std::vector<int64_t> common_axes;
for(auto axis:axes)
{
common_axes.insert(common_axes.end(), cd.axes_map2[axis].begin(), cd.axes_map2[axis].end());
}
auto common_rdims = cd.dims;
for(auto axis:common_axes)
{
common_rdims[axis] = 1;
}
// Topological sort
std::unordered_set<instruction_ref> aux;
auto input_instructions = topo_sort(input_inss, aux);
auto output_instructions = topo_sort(output_inss, aux);
// std::cout << "output_inss:\n";
// m.debug_print({output_inss.begin(), output_inss.end()});
// std::cout << "Output instructions:\n";
// m.debug_print(output_instructions);
// std::cout << "aux:\n";
// m.debug_print({aux.begin(), aux.end()});
auto last = output_instructions.back();
auto insert_reshape = [&](instruction_ref input) {
auto use_rdims = input->get_shape().elements() < nelements;
auto c = m.insert_instruction(last, make_op("contiguous"), input);
return m.insert_instruction(last, make_op("reshape", {{"dims", use_rdims ? common_rdims : cd.dims}}), c);
};
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// map_ins[entry] = insert_reshape(entry);
for(auto i:aux)
{
map_ins[i] = insert_reshape(i);
}
auto inserter = [&](module& mm, instruction_ref i, operation op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_args) {
if (is_reduce(op))
op.from_value({{"axes", common_axes}});
if (is_broadcast(op))
op.from_value({{"out_lens", cd.dims}});
// std::cout << op << std::endl;
// m.debug_print(args);
return mm.insert_instruction(i, op, args, module_args);
};
auto new_x_ins = m.insert_instructions(inserter, last, input_instructions, map_ins).front();
map_ins[reshape_ins] = new_x_ins;
auto new_last = m.insert_instructions(inserter, last, output_instructions, map_ins).front();
auto new_c = m.insert_instruction(last, make_op("contiguous"), new_last);
auto new_reshape = m.insert_instruction(last, make_op("reshape", {{"dims", dims2}}), new_c);
m.debug_print();
m.debug_print(last);
m.debug_print(new_reshape);
m.replace_instruction(last, new_reshape);
std::abort();
}
};
......@@ -1020,7 +1160,7 @@ void simplify_reshapes::apply(module& m) const
find_resize{},
find_nop_reshapes{},
find_reshaper{},
find_broadcast_reshaper{},
// find_broadcast_reshaper{},
// find_reshape_cont{},
find_transpose{},
find_concat_transpose{},
......@@ -1032,7 +1172,8 @@ void simplify_reshapes::apply(module& m) const
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{},
find_mul_add_transpose_contiguous_reshaper_gemm{},
find_reshape_gemm{});
find_reshape_gemm{},
find_poinwise_reduce_reshape{});
dead_code_elimination{}.apply(m);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment