Commit e4759983 authored by charlie's avatar charlie
Browse files

formatting

parent fbea17d7
...@@ -70,7 +70,8 @@ struct nonmaxsuppression ...@@ -70,7 +70,8 @@ struct nonmaxsuppression
}; };
template <class T> template <class T>
box batch_box(const migraphx::tensor_view<T>& boxes, std::size_t box_ind, std::size_t box_idx) const box
batch_box(const migraphx::tensor_view<T>& boxes, std::size_t box_ind, std::size_t box_idx) const
{ {
box result{}; box result{};
auto start = box_ind + 4 * box_idx; auto start = box_ind + 4 * box_idx;
...@@ -92,7 +93,6 @@ struct nonmaxsuppression ...@@ -92,7 +93,6 @@ struct nonmaxsuppression
return result; return result;
} }
inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const
{ {
b1.sort(); b1.sort();
...@@ -128,23 +128,26 @@ struct nonmaxsuppression ...@@ -128,23 +128,26 @@ struct nonmaxsuppression
return intersection_over_union > iou_threshold; return intersection_over_union > iou_threshold;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
std::size_t max_output_boxes_per_class = (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0; std::size_t max_output_boxes_per_class =
if(max_output_boxes_per_class == 0) { return result; } (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
float iou_threshold = (args.size() > 3) ? (args.at(3).at<float>()) : 0.0f; if(max_output_boxes_per_class == 0)
{
return result;
}
float iou_threshold = (args.size() > 3) ? (args.at(3).at<float>()) : 0.0f;
float score_threshold = (args.size() > 4) ? (args.at(4).at<float>()) : 0.0f; float score_threshold = (args.size() > 4) ? (args.at(4).at<float>()) : 0.0f;
result.visit([&](auto output){ result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores){ visit_all(args[0], args[1])([&](auto boxes, auto scores) {
std::fill(output.begin(), output.end(), 0); std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens(); const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0]; const auto num_batches = lens[0];
const auto num_classes = lens[1]; const auto num_classes = lens[1];
const auto num_boxes = boxes.get_shape().lens()[1]; const auto num_boxes = boxes.get_shape().lens()[1];
// boxes of a class with NMS applied [score, index] // boxes of a class with NMS applied [score, index]
std::vector<std::pair<float, int64_t>> selected_boxes_inside_class; std::vector<std::pair<float, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices; std::vector<int64_t> selected_indices;
...@@ -152,59 +155,60 @@ struct nonmaxsuppression ...@@ -152,59 +155,60 @@ struct nonmaxsuppression
// iterate over batches and classes // iterate over batches and classes
shape comp_s{shape::float_type, {num_batches, num_classes}}; shape comp_s{shape::float_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](auto idx) {
auto batch_idx = idx[0]; auto batch_idx = idx[0];
auto class_idx = idx[1]; auto class_idx = idx[1];
// index offset for this class // index offset for this class
std::size_t score_offset_ind = (batch_idx * num_classes + class_idx) * num_boxes; std::size_t score_offset_ind =
// index to first value of this batch (batch_idx * num_classes + class_idx) * num_boxes;
std::size_t batch_boxes_ind = batch_idx * num_boxes * 4; // index to first value of this batch
std::priority_queue<std::pair<float, int64_t>> boxes_heap; std::size_t batch_boxes_ind = batch_idx * num_boxes * 4;
auto insert_to_boxes_heap = make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); }); std::priority_queue<std::pair<float, int64_t>> boxes_heap;
int64_t box_idx = 0; auto insert_to_boxes_heap =
make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
// filter boxes below score_threshold int64_t box_idx = 0;
transform_if(
scores.begin() + score_offset_ind, // filter boxes below score_threshold
scores.begin() + score_offset_ind + num_boxes, transform_if(
insert_to_boxes_heap, scores.begin() + score_offset_ind,
[&](auto sc) { scores.begin() + score_offset_ind + num_boxes,
box_idx++; insert_to_boxes_heap,
return sc >= score_threshold; [&](auto sc) {
}, box_idx++;
[&](auto sc) { return std::make_pair(sc, box_idx - 1); }); return sc >= score_threshold;
},
selected_boxes_inside_class.clear(); [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
// Get the next box with top score, filter by iou_threshold
while(!boxes_heap.empty() && selected_boxes_inside_class.clear();
selected_boxes_inside_class.size() < max_output_boxes_per_class) // Get the next box with top score, filter by iou_threshold
while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
const std::pair<float, int64_t>& next_top_score = boxes_heap.top();
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(
batch_box(boxes, batch_boxes_ind, next_top_score.second),
batch_box(boxes, batch_boxes_ind, selected_index.second),
iou_threshold);
});
if(not not_selected)
{ {
const std::pair<float, int64_t>& next_top_score = boxes_heap.top(); selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(batch_idx);
// Check with existing selected boxes for this class, remove box if it exceeds the IOU selected_indices.push_back(class_idx);
// (Intersection Over Union) threshold selected_indices.push_back(next_top_score.second);
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(
batch_box(boxes, batch_boxes_ind, next_top_score.second),
batch_box(boxes, batch_boxes_ind, selected_index.second),
iou_threshold);
});
if(not not_selected)
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
}
boxes_heap.pop();
} }
boxes_heap.pop();
}
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
} });
);
}); });
return result; return result;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment