/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #ifndef MIGRAPHX_GUARD_OPERATORS_NONMAXSUPPRESSION_HPP #define MIGRAPHX_GUARD_OPERATORS_NONMAXSUPPRESSION_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { struct nonmaxsuppression { bool center_point_box = false; bool use_dyn_output = false; template static auto reflect(Self& self, F f) { return pack(f(self.center_point_box, "center_point_box"), f(self.use_dyn_output, "use_dyn_output")); } std::string name() const { return "nonmaxsuppression"; } shape compute_shape(std::vector inputs) const { // requires at least 2 inputs check_shapes{{inputs.at(0), inputs.at(1)}, *this, true}.only_dims(3).same_ndims(); auto boxes_max_lens = inputs.at(0).max_lens(); // num batches * num boxes const auto max_num_boxes = boxes_max_lens.at(0) * boxes_max_lens.at(1); auto fixed_shape_error_check = [&]() { auto lens = inputs.front().lens(); if(lens[1] != inputs.at(1).lens()[2]) { MIGRAPHX_THROW( "NonMaxSuppression: spatial dimension mismatch between boxes and scores input"); } if(lens[0] != inputs.at(1).lens()[0]) { MIGRAPHX_THROW( "NonMaxSuppression: number of batches mismatch between boxes and scores input"); } }; if(use_dyn_output) { if(inputs.at(0).dynamic()) { // both boxes and scores should be dynamic // check dynamic dimensions are consistent const auto boxes_dims = inputs.at(0).dyn_dims(); const auto scores_dims = inputs.at(1).dyn_dims(); if(boxes_dims.at(1) != scores_dims.at(2)) { MIGRAPHX_THROW("NonMaxSuppression: dynamic spatial dimension mismatch between " "boxes and scores input"); } if(boxes_dims.at(0) != scores_dims.at(0)) { MIGRAPHX_THROW("NonMaxSuppression: dynamic number of batches mismatch between " "boxes and scores input"); } } else if(inputs.at(1).dynamic()) { // scores has dynamic shape, boxes fixed shape // check that it is only a dynamic number of classes const auto scores_dims = inputs.at(1).dyn_dims(); const auto boxes_lens = inputs.at(0).lens(); if(not scores_dims.at(0).is_fixed() or scores_dims.at(0).max != boxes_lens.at(0)) { MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; num_batches not " "fixed or mismatched"); } if(not scores_dims.at(2).is_fixed() or scores_dims.at(2).max != boxes_lens.at(1)) { MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; " "spatial_dimension not fixed or mismatches"); } } else { fixed_shape_error_check(); } std::vector out_lens = {}; out_lens.push_back({0, max_num_boxes, 0}); out_lens.push_back({3, 3, 0}); return {shape::int64_type, out_lens}; } else { if(inputs.at(0).dynamic() or inputs.at(1).dynamic()) { MIGRAPHX_THROW( "NonMaxSuppression: dynamic input shape with use_dyn_output set to false"); } fixed_shape_error_check(); std::vector out_lens = {max_num_boxes, 3}; return {shape::int64_type, out_lens}; } } struct box { std::array x; std::array y; void sort() { std::sort(x.begin(), x.end()); std::sort(y.begin(), y.end()); } std::array& operator[](std::size_t i) { return i == 0 ? x : y; } double area() const { assert(std::is_sorted(x.begin(), x.end())); assert(std::is_sorted(y.begin(), y.end())); return (x[1] - x[0]) * (y[1] - y[0]); } }; template box batch_box(T boxes, std::size_t box_idx) const { box result{}; auto start = boxes + 4 * box_idx; if(center_point_box) { double half_width = start[2] / 2.0; double half_height = start[3] / 2.0; double x_center = start[0]; double y_center = start[1]; result.x = {x_center - half_width, x_center + half_width}; result.y = {y_center - half_height, y_center + half_height}; } else { result.x = {static_cast(start[1]), static_cast(start[3])}; result.y = {static_cast(start[0]), static_cast(start[2])}; } return result; } inline bool suppress_by_iou(box b1, box b2, double iou_threshold) const { b1.sort(); b2.sort(); box intersection{}; for(auto i : range(2)) { intersection[i][0] = std::max(b1[i][0], b2[i][0]); intersection[i][1] = std::min(b1[i][1], b2[i][1]); } std::vector> bbox = {intersection.x, intersection.y}; if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) { return not std::is_sorted(bx.begin(), bx.end()); })) { return false; } const double area1 = b1.area(); const double area2 = b2.area(); const double intersection_area = intersection.area(); const double union_area = area1 + area2 - intersection_area; if(area1 <= .0f or area2 <= .0f or union_area <= .0f) { return false; } const double intersection_over_union = intersection_area / union_area; return intersection_over_union > iou_threshold; } // filter boxes below score_threshold template std::priority_queue> filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const { std::priority_queue> boxes_heap; auto insert_to_boxes_heap = make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); }); int64_t box_idx = 0; transform_if( scores_start, scores_start + num_boxes, insert_to_boxes_heap, [&](auto sc) { box_idx++; return sc >= score_threshold; }, [&](auto sc) { return std::make_pair(sc, box_idx - 1); }); return boxes_heap; } template std::size_t compute_nms(Output output, Boxes boxes, Scores scores, const shape& max_output_shape, std::size_t max_output_boxes_per_class, double iou_threshold, double score_threshold) const { std::fill(output.begin(), output.end(), 0); const auto& lens = scores.get_shape().lens(); const auto num_batches = lens[0]; const auto num_classes = lens[1]; const auto num_boxes = lens[2]; // boxes of a class with NMS applied [score, index] std::vector> selected_boxes_inside_class; std::vector selected_indices; selected_boxes_inside_class.reserve(max_output_shape.elements()); // iterate over batches and classes shape comp_s{shape::double_type, {num_batches, num_classes}}; shape_for_each(comp_s, [&](auto idx) { auto batch_idx = idx[0]; auto class_idx = idx[1]; // index offset for this class auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes; // iterator to first value of this batch auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4; auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold); selected_boxes_inside_class.clear(); // Get the next box with top score, filter by iou_threshold while(not boxes_heap.empty() && selected_boxes_inside_class.size() < max_output_boxes_per_class) { // Check with existing selected boxes for this class, remove box if it // exceeds the IOU (Intersection Over Union) threshold const auto next_top_score = boxes_heap.top(); bool not_selected = std::any_of(selected_boxes_inside_class.begin(), selected_boxes_inside_class.end(), [&](auto selected_index) { return this->suppress_by_iou( batch_box(batch_boxes_start, next_top_score.second), batch_box(batch_boxes_start, selected_index.second), iou_threshold); }); if(not not_selected) { selected_boxes_inside_class.push_back(next_top_score); selected_indices.push_back(batch_idx); selected_indices.push_back(class_idx); selected_indices.push_back(next_top_score.second); } boxes_heap.pop(); } }); std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); return selected_indices.size() / 3; } argument compute(const shape& output_shape, std::vector args) const { // make buffer of maximum size shape max_output_shape = {output_shape.type(), output_shape.max_lens()}; argument result{max_output_shape}; std::size_t max_output_boxes_per_class = (args.size() > 2) ? (args.at(2).at()) : 0; if(max_output_boxes_per_class == 0) { return result; } double iou_threshold = (args.size() > 3) ? (args.at(3).at()) : 0.0f; double score_threshold = (args.size() > 4) ? (args.at(4).at()) : 0.0f; std::size_t num_selected = 0; result.visit([&](auto output) { visit_all(args[0], args[1])([&](auto boxes, auto scores) { num_selected = compute_nms(output, boxes, scores, max_output_shape, max_output_boxes_per_class, iou_threshold, score_threshold); }); }); if(use_dyn_output) { return result.reshape({output_shape.type(), {num_selected, 3}}); } else { return result; } } }; } // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif