#include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { const auto& reshaper_names() { // clang-format off static const std::unordered_set names = { "reshape", "contiguous", "squeeze", "unsqueeze" }; // clang-format on return names; } bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); } instruction_ref find_transpose_input(instruction_ref ins) { if(ins->inputs().size() != 1) return ins; if(ins->inputs().front()->name() == "contiguous") return find_transpose_input(ins->inputs().front()); if(ins->inputs().front()->name() == "transpose") return ins->inputs().front(); return ins; } auto get_transpose_dims(instruction_ref ins) { return any_cast(ins->get_operator()).dims; } std::vector reorder_dims(std::vector dims, std::vector permutation) { std::vector result(dims.size()); assert(dims.size() == permutation.size()); for(std::size_t i = 0; i < dims.size(); i++) { result[i] = dims[permutation[i]]; } return result; } std::vector invert_permutation(const std::vector& permutation) { return reorder_dims(permutation, permutation); } bool is_no_transpose(const std::vector& dims) { if(dims.empty()) return true; if(dims.front() != 0) return false; return std::adjacent_find( dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end(); } template std::vector sort_permutation(const Vector& data, Op op) { std::vector result(data.size()); std::iota(result.begin(), result.end(), 0); std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); }); return result; } std::vector find_permutation(const shape& s) { return sort_permutation(s.strides(), std::greater<>{}); } struct find_reshaper { auto matcher() const { return match::name(reshaper_names())( match::any_of[match::outputs()](match::name(reshaper_names()))); } void apply(program& p, match::matcher_result mr) const { auto ins = mr.result; std::vector reshapes{ins}; while(is_reshaper(reshapes.back())) { assert(!reshapes.back()->inputs().empty()); assert(p.has_instruction(reshapes.back()->inputs().front())); auto input = reshapes.back()->inputs().front(); reshapes.push_back(input); } std::pair r{p.end(), p.end()}; for(auto start : iterator_for(reshapes)) { auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) { return i->get_shape() == (*start)->get_shape() and i != (*start); }); if(last != reshapes.rend()) { r = std::make_pair(*start, *last); break; } } if(r.first != r.second) { p.replace_instruction(r.first, r.second); } } }; MIGRAPHX_PRED_MATCHER(is_transpose_output, instruction_ref start) { return fix([&](auto self, auto ins) { if(ins->outputs().size() != 1) return false; if(ins->outputs().front()->name() == "contiguous") return self(ins->outputs().front()); return ins->outputs().front()->name() == "transpose"; })(start); } struct find_transpose { auto matcher() const { return match::name("transpose")(match::none_of( match::skip_output(match::name("contiguous"))(match::name("transpose")))); } void apply(program& p, match::matcher_result mr) const { auto ins = mr.result; auto x = ins; auto t = ins; std::vector dims(ins->get_shape().lens().size()); std::iota(dims.begin(), dims.end(), 0); do { dims = reorder_dims(get_transpose_dims(t), dims); x = t; t = find_transpose_input(x); } while(x != t and t->name() == "transpose"); if(t == ins or t->name() != "transpose") return; if(is_no_transpose(dims)) { p.replace_instruction(ins, t->inputs().front()); } else { p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front()); } } }; struct find_concat_transpose { auto matcher() const { return match::name("concat")(match::same_shapes(), match::all_of[match::inputs()](match::transpose_shape())); } void apply(program& p, match::matcher_result mr) const { auto ins = mr.result; auto s = ins->inputs().front()->get_shape(); auto op = any_cast(ins->get_operator()); auto permutation = find_permutation(s); auto ipermutaion = invert_permutation(permutation); op.axis = ipermutaion[op.axis]; std::vector inputs; std::transform( ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); }); auto concat = p.insert_instruction(ins, op, inputs); auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat); p.replace_instruction(ins, t); } }; void simplify_reshapes::apply(program& p) const { auto end = std::prev(p.end()); for(auto ins : iterator_for(p)) { if(ins == end and ins->name() == "contiguous") continue; // Skip possible dead instructions if(ins->outputs().empty() and ins != end) continue; match::find_matches(p, ins, find_reshaper{}, find_transpose{}, find_concat_transpose{}); } } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx