/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT) bool skip_propogate(instruction_ref ins) { if(ins->name() == "contiguous") return skip_propogate(ins->inputs().front()); auto&& s = ins->get_shape(); if(s.broadcasted() and not s.scalar()) return true; if(s.scalar() and s.elements() != 1) return true; return false; } bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } void propagate_constant::apply(module& m) const { std::unordered_set const_instrs; auto last = std::prev(m.end()); // Find instructions that can be evaluated to a literal for(auto i : iterator_for(m)) { const bool is_const = is_const_ins(i); if(is_const and i != last) continue; if(i == last and is_const) { const_instrs.insert(i); } else { std::copy_if(i->inputs().begin(), i->inputs().end(), std::inserter(const_instrs, const_instrs.begin()), [&](const instruction_ref ins) { return is_const_ins(ins) and ins->name() != "@literal"; }); } } // Compute literals in parallel std::vector const_instrs_vec{const_instrs.begin(), const_instrs.end()}; std::vector literals(const_instrs_vec.size()); // DEBUG // for(int i = 0; i < const_instrs_vec.size(); ++i) //{ // auto ins = const_instrs_vec[i]; // if(ins->get_shape().type() == shape::half_type) // { // auto inputs = ins->inputs(); // std::vector new_inputs(inputs.size()); // std::vector added_instructions; // std::transform(inputs.begin(), inputs.end(), new_inputs.begin(), [&](auto input) { // auto input_type = input->get_shape().type(); // if(input_type != shape::half_type and input_type != shape::float_type) // return input; // auto ai = m.add_instruction( // make_op("convert", {{"target_type", shape::double_type}}), input); // added_instructions.push_back(ai); // return ai; // }); // auto new_ins = m.add_instruction(ins->get_operator(), new_inputs); // added_instructions.push_back(new_ins); // auto after_convert = m.add_instruction( // make_op("convert", {{"target_type", ins->get_shape().type()}}), new_ins); // added_instructions.push_back(after_convert); // literals[i] = after_convert->eval(); // for(auto a_ins : added_instructions) // { // m.remove_instruction(a_ins); // } // } // else // { // literals[i] = const_instrs_vec[i]->eval(); // } // } // Original par_for(const_instrs_vec.size(), 1, [&](const auto i) { literals[i] = const_instrs_vec[i]->eval(); }); // Replace instructions in m for(size_t i = 0; i < const_instrs_vec.size(); i++) { if(not literals[i].empty()) { if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{})) { std::cout << "Constant replace: " << std::endl; std::vector inss; fix([&](auto self, auto ins) { if(contains(inss, ins)) return; for(auto input : ins->inputs()) self(input); inss.push_back(ins); })(const_instrs_vec[i]); m.debug_print(inss); } assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape()); auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); m.replace_instruction(const_instrs_vec[i], l); } } } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx