Commit 390b87ae authored by charlie's avatar charlie
Browse files

Formatting

parent 2c1cdd15
...@@ -111,8 +111,8 @@ struct nonmaxsuppression ...@@ -111,8 +111,8 @@ struct nonmaxsuppression
double half_height = start[3] / 2.0; double half_height = start[3] / 2.0;
double x_center = start[0]; double x_center = start[0];
double y_center = start[1]; double y_center = start[1];
result.x = {x_center - half_width, x_center + half_width}; result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height}; result.y = {y_center - half_height, y_center + half_height};
} }
else else
{ {
...@@ -157,11 +157,11 @@ struct nonmaxsuppression ...@@ -157,11 +157,11 @@ struct nonmaxsuppression
return intersection_over_union > iou_threshold; return intersection_over_union > iou_threshold;
} }
// filter boxes below score_threshold // filter boxes below score_threshold
template <class T> template <class T>
std::priority_queue<std::pair<double, int64_t>> filter_boxes_by_score( std::priority_queue<std::pair<double, int64_t>>
T scores_start, std::size_t num_boxes, double score_threshold) const filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const
{ {
std::priority_queue<std::pair<double, int64_t>> boxes_heap; std::priority_queue<std::pair<double, int64_t>> boxes_heap;
auto insert_to_boxes_heap = auto insert_to_boxes_heap =
...@@ -178,36 +178,34 @@ struct nonmaxsuppression ...@@ -178,36 +178,34 @@ struct nonmaxsuppression
[&](auto sc) { return std::make_pair(sc, box_idx - 1); }); [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
return boxes_heap; return boxes_heap;
} }
template <class H, class S> template <class H, class S>
void select_boxes( void select_boxes(H& boxes_heap,
H& boxes_heap, std::vector<std::pair<double, int64_t>>& selected_boxes_inside_class,
std::vector<std::pair<double, int64_t>>& selected_boxes_inside_class, std::vector<int64_t>& selected_indices,
std::vector<int64_t>& selected_indices, S batch_boxes_start,
S batch_boxes_start, std::size_t max_output_boxes_per_class,
std::size_t max_output_boxes_per_class, double iou_threshold,
double iou_threshold, std::size_t batch_idx,
std::size_t batch_idx, std::size_t class_idx) const
std::size_t class_idx
) const
{ {
selected_boxes_inside_class.clear(); selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold // Get the next box with top score, filter by iou_threshold
while(!boxes_heap.empty() && while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class) selected_boxes_inside_class.size() < max_output_boxes_per_class)
{ {
// Check with existing selected boxes for this class, remove box if it // Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold // exceeds the IOU (Intersection Over Union) threshold
const auto next_top_score = boxes_heap.top(); const auto next_top_score = boxes_heap.top();
bool not_selected = std::any_of( bool not_selected =
selected_boxes_inside_class.begin(), std::any_of(selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(), selected_boxes_inside_class.end(),
[&](auto selected_index) { [&](auto selected_index) {
return this->suppress_by_iou( return this->suppress_by_iou(
batch_box(batch_boxes_start, next_top_score.second), batch_box(batch_boxes_start, next_top_score.second),
batch_box(batch_boxes_start, selected_index.second), batch_box(batch_boxes_start, selected_index.second),
iou_threshold); iou_threshold);
}); });
if(not not_selected) if(not not_selected)
{ {
...@@ -219,7 +217,7 @@ struct nonmaxsuppression ...@@ -219,7 +217,7 @@ struct nonmaxsuppression
boxes_heap.pop(); boxes_heap.pop();
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -235,29 +233,37 @@ struct nonmaxsuppression ...@@ -235,29 +233,37 @@ struct nonmaxsuppression
result.visit([&](auto output) { result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) { visit_all(args[0], args[1])([&](auto boxes, auto scores) {
std::fill(output.begin(), output.end(), 0); std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens(); const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0]; const auto num_batches = lens[0];
const auto num_classes = lens[1]; const auto num_classes = lens[1];
const auto num_boxes = lens[2]; const auto num_boxes = lens[2];
// boxes of a class with NMS applied [score, index] // boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class; std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices; std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements()); selected_boxes_inside_class.reserve(output_shape.elements());
// iterate over batches and classes // iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}}; shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](auto idx) {
auto batch_idx = idx[0]; auto batch_idx = idx[0];
auto class_idx = idx[1]; auto class_idx = idx[1];
// index offset for this class // index offset for this class
auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes; auto scores_start =
// iterator to first value of this batch scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4; // iterator to first value of this batch
auto boxes_heap = auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
filter_boxes_by_score(scores_start, num_boxes, score_threshold); auto boxes_heap =
select_boxes(boxes_heap, selected_boxes_inside_class, selected_indices, batch_boxes_start, max_output_boxes_per_class, iou_threshold, batch_idx, class_idx); filter_boxes_by_score(scores_start, num_boxes, score_threshold);
}); select_boxes(boxes_heap,
std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); selected_boxes_inside_class,
selected_indices,
batch_boxes_start,
max_output_boxes_per_class,
iou_threshold,
batch_idx,
class_idx);
});
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment