#include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { static bool try_compute_shape(instruction_ref ins, const std::vector& inputs) { try { shape new_shape = ins->get_operator().compute_shape(inputs); // If the output shape is a standard shape, no need to try its output if(new_shape.standard()) { return true; } // if no changes for the shape, the contiguous can also be removed if(new_shape == ins->get_shape()) { return true; } auto outputs = ins->outputs(); // If the current instruction has no output, it means it is the last // instruction and generates a non-standard output shape, and the last // output shape is different from the case with the contiguous operator if(outputs.empty()) { return false; } for(auto output : outputs) { auto args = output->inputs(); std::vector input_shapes(args.size()); std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto& arg) { return (arg == ins) ? new_shape : arg->get_shape(); }); if(!try_compute_shape(output, input_shapes)) { return false; } } } catch(...) { return false; } return true; } static bool try_compute_shape(instruction_ref ins, const std::vector& args) { auto inputs = to_shapes(args); return try_compute_shape(ins, inputs); } void eliminate_contiguous::apply(module& p) const { for(auto ins : iterator_for(p)) { // return instruction should have inputs with standard shape if(ins->name() == "@return") continue; if (std::none_of(ins->inputs().begin(), ins->inputs().end(), [&](auto arg) { return arg->name() == op_name; })) continue; // Make a copy so we can modify it while we iterate auto args = ins->inputs(); auto new_args = args; std::transform(new_args.begin(), new_args.end(), new_args.begin(), [&](auto arg) { if(arg->name() == op_name) return arg->inputs().front(); else return arg; }); assert(args.size() == new_args.size()); if(try_compute_shape(ins, new_args)) { for(auto i:range(args.size())) { if (args[i] == new_args[i]) continue; instruction::replace_argument(ins, args[i], new_args[i]); } } else { for(auto arg : ins->inputs()) { if(arg->name() == op_name) { new_args = args; auto prev = arg->inputs().front(); replace(new_args, arg, prev); if(try_compute_shape(ins, new_args)) { instruction::replace_argument(ins, arg, prev); } else if(prev->can_eval()) { auto c = op::contiguous{}; auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()}); auto l = p.add_literal(r.get_shape(), r.data()); p.replace_instruction(arg, l); } } } } } } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx