/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { void calc_reflect_indices(std::vector& indices, const int64_t num_dims) { int k = 0; bool reversed = false; // in reflect padding, if the num_pads > num_dims, // compute the extra pad indices periodically, ex. ( 1, 2, 3, 2, 1, 0) for(int& idx : indices) { if(k == num_dims - 1) reversed = true; if(k == 0) reversed = false; if(reversed) k--; else k++; idx = k; } } instruction_ref reflect_pad(const onnx_parser::node_info& info, const std::vector& pads, instruction_ref input) { size_t num_dims = pads.size() / 2; std::vector ldims(pads.begin(), pads.begin() + num_dims); std::vector rdims(pads.begin() + num_dims, pads.end()); assert(ldims.size() == rdims.size()); std::vector axes(num_dims); std::iota(axes.begin(), axes.end(), int64_t{0}); // iterate over dimensions, starting from lowest dimension for(int64_t i = num_dims - 1; i >= 0; i--) { auto axis = i; auto lcount = ldims.at(i); auto rcount = rdims.at(i); if(lcount == 0 and rcount == 0) // no padding for current dim continue; // calculate starts and ends for each iteration since shape may change std::vector dims = input->get_shape().lens(); std::vector starts(axes.size(), 0); std::vector ends(dims.begin(), dims.end()); std::vector slices; auto starts_it = starts.begin() + i; auto ends_it = ends.begin() + i; auto dims_it = dims.begin() + i; std::vector l_indices(lcount); std::vector r_indices(rcount); // compute slice indices in a periodic fashion calc_reflect_indices(l_indices, *dims_it); calc_reflect_indices(r_indices, *dims_it); for(int idx : l_indices) { *starts_it = idx; *ends_it = *starts_it + 1; slices.push_back(info.add_instruction( make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input)); } // when padding on the left side, the outermost pad should be at the beginning std::reverse(slices.begin(), slices.end()); slices.push_back(input); for(int idx : r_indices) { *starts_it = *dims_it - idx - 1; *ends_it = *starts_it + 1; slices.push_back(info.add_instruction( make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input)); } input = info.add_instruction(make_op("concat", {{"axis", axis}}), slices); } return input; } struct parse_pad : op_parser { std::vector operators() const { return {{"Pad"}}; } instruction_ref parse(const op_desc& /*opd*/, const onnx_parser& parser, onnx_parser::node_info info, std::vector args) const { std::vector pads{}; if(args.size() >= 2) { auto pad_arg = args.at(1)->eval(); check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant"); pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); }); } else if(contains(info.attributes, "pads")) { auto&& pad_vals = info.attributes["pads"].ints(); pads = std::vector(pad_vals.begin(), pad_vals.end()); } else { MIGRAPHX_THROW("PARSE_PAD: pad must be available"); } // check if padding is actually being done (at least one value is nonzero) if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; })) { return info.add_instruction(make_op("identity"), args.front()); } if(contains(info.attributes, "mode")) { auto mode = info.attributes.at("mode").s(); if(mode == "reflect") return reflect_pad(info, pads, args.front()); if(mode != "constant") { MIGRAPHX_THROW( "PARSE_PAD: migraphx currently only supports constant and reflect padding"); } } float value = 0.0f; // third input is the value if(args.size() == 3) { auto val_ins = args.at(2); if(not val_ins->can_eval()) { MIGRAPHX_THROW("PARSE_PAD: input value must be constant"); } auto val_arg = val_ins->eval(); if(val_arg.get_shape().elements() != 1) { MIGRAPHX_THROW("PARSE_PAD: value should contain only one element"); } value = val_arg.at(); } else if(contains(info.attributes, "value")) { value = parser.parse_value(info.attributes.at("value")).at(); } return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}), args.front()); } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx