Commit 422d2c73 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Use int64_t to track selected_boxes_inside_class

This cleans up the compute_nms signature as well as stops using additional
memory by not storing every pair result twice that just gets cleared per run each shape_for_each()
parent 680ae7cc
......@@ -260,7 +260,6 @@ struct nonmaxsuppression
std::size_t compute_nms(Output output,
Boxes boxes,
Scores scores,
const shape& max_output_shape,
std::size_t max_output_boxes_per_class,
double iou_threshold,
double score_threshold) const
......@@ -271,9 +270,7 @@ struct nonmaxsuppression
const auto num_classes = lens[1];
const auto num_boxes = lens[2];
// boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(max_output_shape.elements());
// iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) {
......@@ -284,20 +281,21 @@ struct nonmaxsuppression
// iterator to first value of this batch
auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
selected_boxes_inside_class.clear();
int64_t selected_boxes_inside_class = 0;
while(not boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
selected_boxes_inside_class < max_output_boxes_per_class)
{
// select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
// threshold with the selected box
const auto next_top_score = boxes_heap.top();
auto next_box = batch_box(batch_boxes_start, next_top_score.second);
auto next_box_idx = next_top_score.second;
boxes_heap.pop();
selected_boxes_inside_class.push_back(next_top_score);
selected_boxes_inside_class++;
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
selected_indices.push_back(next_box_idx);
std::priority_queue<std::pair<double, int64_t>> remainder_boxes;
while(not boxes_heap.empty())
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment