"script/profile_reduce_with_index.sh" did not exist on "12dfba3d03f402c051e2129fa21f33264f4d26e5"
Unverified Commit e4975990 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Prune candidates in NMS (#1601)

* NMS improvements
parent d83b8397
......@@ -38,6 +38,7 @@ void eliminate_data_type::apply(module& m) const
"if",
"loop",
"roialign",
"nonmaxsuppression",
"scatternd_add",
"scatternd_mul",
"scatternd_none"};
......
......@@ -143,16 +143,22 @@ struct nonmaxsuppression
void sort()
{
std::sort(x.begin(), x.end());
std::sort(y.begin(), y.end());
if(x[0] > x[1])
{
std::swap(x[0], x[1]);
}
if(y[0] > y[1])
{
std::swap(y[0], y[1]);
}
}
std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
double area() const
{
assert(std::is_sorted(x.begin(), x.end()));
assert(std::is_sorted(y.begin(), y.end()));
assert(x[0] <= x[1]);
assert(y[0] <= y[1]);
return (x[1] - x[0]) * (y[1] - y[0]);
}
};
......@@ -190,14 +196,10 @@ struct nonmaxsuppression
{
intersection[i][0] = std::max(b1[i][0], b2[i][0]);
intersection[i][1] = std::min(b1[i][1], b2[i][1]);
}
std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
return not std::is_sorted(bx.begin(), bx.end());
}))
{
return false;
if(intersection[i][0] > intersection[i][1])
{
return false;
}
}
const double area1 = b1.area();
......@@ -265,31 +267,31 @@ struct nonmaxsuppression
auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(not boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
// select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
// threshold with the selected box
const auto next_top_score = boxes_heap.top();
bool not_selected =
std::any_of(selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(
batch_box(batch_boxes_start, next_top_score.second),
batch_box(batch_boxes_start, selected_index.second),
iou_threshold);
});
if(not not_selected)
boxes_heap.pop();
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
std::priority_queue<std::pair<double, int64_t>> remainder_boxes;
while(not boxes_heap.empty())
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
auto iou_candidate_box = boxes_heap.top();
if(not this->suppress_by_iou(
batch_box(batch_boxes_start, iou_candidate_box.second),
batch_box(batch_boxes_start, next_top_score.second),
iou_threshold))
{
remainder_boxes.push(iou_candidate_box);
}
boxes_heap.pop();
}
boxes_heap.pop();
boxes_heap = remainder_boxes;
}
});
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment