Unverified Commit e4975990 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Prune candidates in NMS (#1601)

* NMS improvements
parent d83b8397
...@@ -38,6 +38,7 @@ void eliminate_data_type::apply(module& m) const ...@@ -38,6 +38,7 @@ void eliminate_data_type::apply(module& m) const
"if", "if",
"loop", "loop",
"roialign", "roialign",
"nonmaxsuppression",
"scatternd_add", "scatternd_add",
"scatternd_mul", "scatternd_mul",
"scatternd_none"}; "scatternd_none"};
......
...@@ -143,16 +143,22 @@ struct nonmaxsuppression ...@@ -143,16 +143,22 @@ struct nonmaxsuppression
void sort() void sort()
{ {
std::sort(x.begin(), x.end()); if(x[0] > x[1])
std::sort(y.begin(), y.end()); {
std::swap(x[0], x[1]);
}
if(y[0] > y[1])
{
std::swap(y[0], y[1]);
}
} }
std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; } std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
double area() const double area() const
{ {
assert(std::is_sorted(x.begin(), x.end())); assert(x[0] <= x[1]);
assert(std::is_sorted(y.begin(), y.end())); assert(y[0] <= y[1]);
return (x[1] - x[0]) * (y[1] - y[0]); return (x[1] - x[0]) * (y[1] - y[0]);
} }
}; };
...@@ -190,14 +196,10 @@ struct nonmaxsuppression ...@@ -190,14 +196,10 @@ struct nonmaxsuppression
{ {
intersection[i][0] = std::max(b1[i][0], b2[i][0]); intersection[i][0] = std::max(b1[i][0], b2[i][0]);
intersection[i][1] = std::min(b1[i][1], b2[i][1]); intersection[i][1] = std::min(b1[i][1], b2[i][1]);
} if(intersection[i][0] > intersection[i][1])
{
std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y}; return false;
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) { }
return not std::is_sorted(bx.begin(), bx.end());
}))
{
return false;
} }
const double area1 = b1.area(); const double area1 = b1.area();
...@@ -265,31 +267,31 @@ struct nonmaxsuppression ...@@ -265,31 +267,31 @@ struct nonmaxsuppression
auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4; auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold); auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
selected_boxes_inside_class.clear(); selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(not boxes_heap.empty() && while(not boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class) selected_boxes_inside_class.size() < max_output_boxes_per_class)
{ {
// Check with existing selected boxes for this class, remove box if it // select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
// exceeds the IOU (Intersection Over Union) threshold // threshold with the selected box
const auto next_top_score = boxes_heap.top(); const auto next_top_score = boxes_heap.top();
bool not_selected = boxes_heap.pop();
std::any_of(selected_boxes_inside_class.begin(), selected_boxes_inside_class.push_back(next_top_score);
selected_boxes_inside_class.end(), selected_indices.push_back(batch_idx);
[&](auto selected_index) { selected_indices.push_back(class_idx);
return this->suppress_by_iou( selected_indices.push_back(next_top_score.second);
batch_box(batch_boxes_start, next_top_score.second), std::priority_queue<std::pair<double, int64_t>> remainder_boxes;
batch_box(batch_boxes_start, selected_index.second), while(not boxes_heap.empty())
iou_threshold);
});
if(not not_selected)
{ {
selected_boxes_inside_class.push_back(next_top_score); auto iou_candidate_box = boxes_heap.top();
selected_indices.push_back(batch_idx); if(not this->suppress_by_iou(
selected_indices.push_back(class_idx); batch_box(batch_boxes_start, iou_candidate_box.second),
selected_indices.push_back(next_top_score.second); batch_box(batch_boxes_start, next_top_score.second),
iou_threshold))
{
remainder_boxes.push(iou_candidate_box);
}
boxes_heap.pop();
} }
boxes_heap.pop(); boxes_heap = remainder_boxes;
} }
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment