/* * The MIT License (MIT) * * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { /** * Convert 2 input static shape broadcast/multibroadcast into 1 input version. * Some compiler passes (ex. simplify_algebra) only support the 1 input versions * of the broadcasting operators. */ struct find_static_2in_broadcasts { auto matcher() const { return match::broadcast(match::nargs(2), match::arg(0)(match::static_shape()), match::arg(1)(match::static_shape())); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto out_lens = ins->get_shape().lens(); auto broadcast_op = ins->get_operator(); if(broadcast_op.name() == "broadcast") { broadcast_op.from_value({{"out_lens", out_lens}}); } else { broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}}); } m.replace_instruction(ins, broadcast_op, ins->inputs().at(0)); } }; /** * Simplify slice with variable `starts` and `ends` to the constant version if * the `input_starts` and `input_ends` inputs are constant. */ struct find_const_3in_slice { auto matcher() const { return match::name("slice")(match::nargs(3), match::arg(1)(match::is_constant()), match::arg(2)(match::is_constant())); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto inputs = ins->inputs(); argument starts_arg = inputs.at(1)->eval(); argument ends_arg = inputs.at(2)->eval(); if(not starts_arg.empty() and not ends_arg.empty()) { std::vector starts_vec; std::vector ends_vec; starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); }); ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); }); auto slice_val = ins->get_operator().to_value(); auto axes_vec = slice_val.at("axes").to_vector(); m.replace_instruction( ins, make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}), inputs.at(0)); } } }; /** * Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if * the `input_starts`, `input_ends`, and `input_axes` inputs are constant. */ struct find_const_4in_slice { auto matcher() const { return match::name("slice")(match::nargs(4), match::arg(1)(match::is_constant()), match::arg(2)(match::is_constant()), match::arg(3)(match::is_constant())); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto inputs = ins->inputs(); argument starts_arg = inputs.at(1)->eval(); argument ends_arg = inputs.at(2)->eval(); argument axes_arg = inputs.at(3)->eval(); if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty()) { std::vector starts_vec; std::vector ends_vec; std::vector axes_vec; starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); }); ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); }); axes_arg.visit([&](auto output) { axes_vec.assign(output.begin(), output.end()); }); m.replace_instruction( ins, make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}), inputs.at(0)); } } }; void simplify_dyn_ops::apply(module& m) const { match::find_matches( m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{}); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx