/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { static bool try_compute_shape(instruction_ref ins, const std::vector& inputs, const std::vector& mods) { try { shape new_shape = ins->get_operator().compute_shape(inputs, mods); // Cannot tell if a dynamic shape will need to be made contiguous if(new_shape.dynamic()) { return false; } // If the output shape is a standard shape, no need to try its output if(new_shape.standard()) { return true; } // if no changes for the shape, the contiguous can also be removed if(new_shape == ins->get_shape()) { return true; } auto outputs = ins->outputs(); // If the current instruction has no output, it means it is the last // instruction and generates a non-standard output shape, and the last // output shape is different from the case with the contiguous operator if(outputs.empty()) { return false; } for(auto output : outputs) { auto args = output->inputs(); std::vector input_shapes(args.size()); std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto& arg) { return (arg == ins) ? new_shape : arg->get_shape(); }); if(not try_compute_shape(output, input_shapes, mods)) { return false; } } } catch(...) { return false; } return true; } static bool try_compute_shape(instruction_ref ins, const std::vector& args, const std::vector& mods) { auto inputs = to_shapes(args); return try_compute_shape(ins, inputs, mods); } template static void remove_contiguous(const std::string& op_name, module& m, F f) { auto last = std::prev(m.end()); std::vector const_instructions; for(auto ins : iterator_for(m)) { // return instruction should have inputs with standard shape if(ins->name() == "@return") continue; if(ins != last and ins->outputs().empty()) continue; if(not f(ins)) continue; // Make a copy so we can modify it while we iterate auto args = ins->inputs(); auto new_args = args; auto mod_args = ins->module_inputs(); for(auto arg : ins->inputs()) { if(arg->name() != op_name) continue; auto prev = arg->inputs().front(); replace(new_args, arg, prev); if(try_compute_shape(ins, new_args, mod_args)) { instruction::replace_argument(ins, arg, prev); } else if(prev->can_eval()) { const_instructions.push_back(arg); } } } // Perform static contiguous evaluations in parallel std::vector literals(const_instructions.size()); par_for(const_instructions.size(), 1, [&](const auto i) { auto c = op::contiguous{}; auto prev = const_instructions[i]->inputs().front(); // compute the output contiguous shape from the previous instruction shape shape computed_shape = c.compute_shape({prev->get_shape()}); const std::vector& prev_eval = {prev->eval()}; // prev_eval should not be used in make_compute_output_shape() as computed_shape is static auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval)); literals[i] = c.compute(co_shape, prev_eval); }); // Replace static contiguous operations with a literal for(size_t i = 0; i < const_instructions.size(); i++) { auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); m.replace_instruction(const_instructions[i], l); } } static void remove_contiguous_noops(const std::string& op_name, module& m) { for(auto ins : iterator_for(m)) { if (ins->name() != op_name) continue; if(ins->inputs().front()->get_shape() != ins->get_shape()) continue; m.replace_instruction(ins, ins->inputs().front()); } } void eliminate_contiguous::apply(module& m) const { // Skip contiguous from splits first remove_contiguous(op_name, m, [](auto ins) { if(ins->name() != "slice") return true; return (ins->inputs().front()->outputs().size() == 1); }); remove_contiguous(op_name, m, [](auto) { return true; }); remove_contiguous_noops(op_name, m); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx