Commit 571a3464 authored by charlie's avatar charlie
Browse files

Tidy complexity fix

parent b2efe895
...@@ -70,8 +70,7 @@ struct nonmaxsuppression ...@@ -70,8 +70,7 @@ struct nonmaxsuppression
}; };
template <class T> template <class T>
box box batch_box(const T& boxes, std::size_t box_ind, std::size_t box_idx) const
batch_box(const migraphx::tensor_view<T>& boxes, std::size_t box_ind, std::size_t box_idx) const
{ {
box result{}; box result{};
auto start = box_ind + 4 * box_idx; auto start = box_ind + 4 * box_idx;
...@@ -128,6 +127,29 @@ struct nonmaxsuppression ...@@ -128,6 +127,29 @@ struct nonmaxsuppression
return intersection_over_union > iou_threshold; return intersection_over_union > iou_threshold;
} }
// filter boxes below score_threshold
template <class T>
void filter_boxes_by_score(
T scores,
std::size_t score_offset_ind,
std::size_t num_boxes,
const float score_threshold,
std::priority_queue<std::pair<float, int64_t>>& boxes_heap) const
{
auto insert_to_boxes_heap =
make_function_output_iterator([&boxes_heap](const auto& x) { boxes_heap.push(x); });
int64_t box_idx = 0;
transform_if(
scores.begin() + score_offset_ind,
scores.begin() + score_offset_ind + num_boxes,
insert_to_boxes_heap,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
}
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -163,30 +185,15 @@ struct nonmaxsuppression ...@@ -163,30 +185,15 @@ struct nonmaxsuppression
// index to first value of this batch // index to first value of this batch
std::size_t batch_boxes_ind = batch_idx * num_boxes * 4; std::size_t batch_boxes_ind = batch_idx * num_boxes * 4;
std::priority_queue<std::pair<float, int64_t>> boxes_heap; std::priority_queue<std::pair<float, int64_t>> boxes_heap;
auto insert_to_boxes_heap = filter_boxes_by_score(scores, score_offset_ind, num_boxes, score_threshold, boxes_heap);
make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
int64_t box_idx = 0;
// filter boxes below score_threshold
transform_if(
scores.begin() + score_offset_ind,
scores.begin() + score_offset_ind + num_boxes,
insert_to_boxes_heap,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
selected_boxes_inside_class.clear(); selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold // Get the next box with top score, filter by iou_threshold
while(!boxes_heap.empty() && while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class) selected_boxes_inside_class.size() < max_output_boxes_per_class)
{ {
const std::pair<float, int64_t>& next_top_score = boxes_heap.top();
// Check with existing selected boxes for this class, remove box if it // Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold // exceeds the IOU (Intersection Over Union) threshold
const auto next_top_score = boxes_heap.top();
bool not_selected = std::any_of( bool not_selected = std::any_of(
selected_boxes_inside_class.begin(), selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(), selected_boxes_inside_class.end(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment