/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#ifndef MIGRAPHX_GUARD_OPERATORS_NONMAXSUPPRESSION_HPP
#define MIGRAPHX_GUARD_OPERATORS_NONMAXSUPPRESSION_HPP

#include <cmath>
#include <queue>
#include <cstdint>
#include <iterator>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/output_iterator.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct nonmaxsuppression
{
    bool center_point_box = false;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.center_point_box, "center_point_box"));
    }

    std::string name() const { return "nonmaxsuppression"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        // requires at least 2 inputs
        check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
        auto lens = inputs.front().lens();

        // check input shape
        if(lens[1] != inputs.at(1).lens()[2])
        {
65
66
67
68
69
70
71
72
73
            MIGRAPHX_THROW(
                "NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
        }

        // check batch sizes
        if(lens[0] != inputs.at(1).lens()[0])
        {
            MIGRAPHX_THROW(
                "NonMaxSuppression: number of batches mismatch between boxes and scores input");
74
75
76
77
78
79
80
81
82
83
        }

        std::vector<int64_t> out_lens(2);
        out_lens.at(0) = lens.at(1);
        out_lens.at(1) = 3;
        return {shape::int64_type, out_lens};
    }

    struct box
    {
84
85
        std::array<double, 2> x;
        std::array<double, 2> y;
86
87
88
89
90
91
92

        void sort()
        {
            std::sort(x.begin(), x.end());
            std::sort(y.begin(), y.end());
        }

93
        std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
94

95
        double area() const
96
97
98
99
100
101
102
103
        {
            assert(std::is_sorted(x.begin(), x.end()));
            assert(std::is_sorted(y.begin(), y.end()));
            return (x[1] - x[0]) * (y[1] - y[0]);
        }
    };

    template <class T>
104
    box batch_box(T boxes, std::size_t box_idx) const
105
106
    {
        box result{};
107
        auto start = boxes + 4 * box_idx;
108
109
        if(center_point_box)
        {
110
111
112
113
114
115
            double half_width  = start[2] / 2.0;
            double half_height = start[3] / 2.0;
            double x_center    = start[0];
            double y_center    = start[1];
            result.x           = {x_center - half_width, x_center + half_width};
            result.y           = {y_center - half_height, y_center + half_height};
116
117
118
        }
        else
        {
119
120
            result.x = {static_cast<double>(start[1]), static_cast<double>(start[3])};
            result.y = {static_cast<double>(start[0]), static_cast<double>(start[2])};
121
122
123
124
125
        }

        return result;
    }

126
    inline bool suppress_by_iou(box b1, box b2, double iou_threshold) const
127
128
129
130
131
132
133
134
135
136
137
    {
        b1.sort();
        b2.sort();

        box intersection{};
        for(auto i : range(2))
        {
            intersection[i][0] = std::max(b1[i][0], b2[i][0]);
            intersection[i][1] = std::min(b1[i][1], b2[i][1]);
        }

138
        std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
139
140
141
142
143
144
145
        if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
               return not std::is_sorted(bx.begin(), bx.end());
           }))
        {
            return false;
        }

146
147
148
149
        const double area1             = b1.area();
        const double area2             = b2.area();
        const double intersection_area = intersection.area();
        const double union_area        = area1 + area2 - intersection_area;
150
151
152
153
154
155

        if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
        {
            return false;
        }

156
        const double intersection_over_union = intersection_area / union_area;
157
158
159
160

        return intersection_over_union > iou_threshold;
    }

161
162
163
164
    // filter boxes below score_threshold
    template <class T>
    std::priority_queue<std::pair<double, int64_t>>
    filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const
165
    {
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
        std::priority_queue<std::pair<double, int64_t>> boxes_heap;
        auto insert_to_boxes_heap =
            make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
        int64_t box_idx = 0;
        transform_if(
            scores_start,
            scores_start + num_boxes,
            insert_to_boxes_heap,
            [&](auto sc) {
                box_idx++;
                return sc >= score_threshold;
            },
            [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
        return boxes_heap;
    }
181

182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
    template <class Output, class Boxes, class Scores>
    void compute_nms(Output output,
                     Boxes boxes,
                     Scores scores,
                     const shape& output_shape,
                     std::size_t max_output_boxes_per_class,
                     double iou_threshold,
                     double score_threshold) const
    {
        std::fill(output.begin(), output.end(), 0);
        const auto& lens       = scores.get_shape().lens();
        const auto num_batches = lens[0];
        const auto num_classes = lens[1];
        const auto num_boxes   = lens[2];
        // boxes of a class with NMS applied [score, index]
        std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
198
199
        std::vector<int64_t> selected_indices;
        selected_boxes_inside_class.reserve(output_shape.elements());
200
201
        // iterate over batches and classes
        shape comp_s{shape::double_type, {num_batches, num_classes}};
202
        shape_for_each(comp_s, [&](auto idx) {
203
204
205
206
207
208
209
            auto batch_idx = idx[0];
            auto class_idx = idx[1];
            // index offset for this class
            auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
            // iterator to first value of this batch
            auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
            auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
210
211
            selected_boxes_inside_class.clear();
            // Get the next box with top score, filter by iou_threshold
212
            while(!boxes_heap.empty() &&
213
214
                  selected_boxes_inside_class.size() < max_output_boxes_per_class)
            {
215
216
217
218
219
220
221
222
223
224
225
226
                // Check with existing selected boxes for this class, remove box if it
                // exceeds the IOU (Intersection Over Union) threshold
                const auto next_top_score = boxes_heap.top();
                bool not_selected =
                    std::any_of(selected_boxes_inside_class.begin(),
                                selected_boxes_inside_class.end(),
                                [&](auto selected_index) {
                                    return this->suppress_by_iou(
                                        batch_box(batch_boxes_start, next_top_score.second),
                                        batch_box(batch_boxes_start, selected_index.second),
                                        iou_threshold);
                                });
227
228
229
230

                if(not not_selected)
                {
                    selected_boxes_inside_class.push_back(next_top_score);
231
232
                    selected_indices.push_back(batch_idx);
                    selected_indices.push_back(class_idx);
233
234
                    selected_indices.push_back(next_top_score.second);
                }
235
                boxes_heap.pop();
236
237
            }
        });
238
239
240
241
242
243
        std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
244

245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
        std::size_t max_output_boxes_per_class =
            (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
        if(max_output_boxes_per_class == 0)
        {
            return result;
        }
        double iou_threshold   = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
        double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;

        result.visit([&](auto output) {
            visit_all(args[0], args[1])([&](auto boxes, auto scores) {
                compute_nms(output,
                            boxes,
                            scores,
                            output_shape,
                            max_output_boxes_per_class,
                            iou_threshold,
                            score_threshold);
            });
264
265
266
267
268
269
270
271
272
273
274
        });

        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif