Commit fbea17d7 authored by charlie's avatar charlie
Browse files

NMS refactor and nonstd shape

parent c650e2a4
......@@ -33,7 +33,6 @@ struct nonmaxsuppression
shape compute_shape(std::vector<shape> inputs) const
{
// requires at least 2 inputs
check_shapes{inputs, *this};
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
auto lens = inputs.front().lens();
......@@ -71,28 +70,29 @@ struct nonmaxsuppression
};
template <class T>
box batch_box(const T* boxes, std::size_t bidx) const
box batch_box(const migraphx::tensor_view<T>& boxes, std::size_t box_ind, std::size_t box_idx) const
{
box result{};
const T* start = boxes + 4 * bidx;
auto start = box_ind + 4 * box_idx;
if(center_point_box)
{
float half_width = start[2] / 2.0f;
float half_height = start[3] / 2.0f;
float x_center = start[0];
float y_center = start[1];
float half_width = boxes[start + 2] / 2.0;
float half_height = boxes[start + 3] / 2.0;
float x_center = boxes[start + 0];
float y_center = boxes[start + 1];
result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height};
}
else
{
result.x = {start[1], start[3]};
result.y = {start[0], start[2]};
result.x = {static_cast<float>(boxes[start + 1]), static_cast<float>(boxes[start + 3])};
result.y = {static_cast<float>(boxes[start + 0]), static_cast<float>(boxes[start + 2])};
}
return result;
}
inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const
{
b1.sort();
......@@ -128,100 +128,83 @@ struct nonmaxsuppression
return intersection_over_union > iou_threshold;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); });
std::size_t max_output_boxes_per_class = 0;
float iou_threshold = 0.0f;
float score_threshold = 0.0f;
if(args.size() > 2)
{
max_output_boxes_per_class = args.at(2).at<std::size_t>();
}
// max_output_boxes_per_class is 0, no output
if(max_output_boxes_per_class == 0)
{
return result;
}
if(args.size() > 3)
{
iou_threshold = args.at(3).at<float>();
}
if(args.size() > 4)
{
score_threshold = args.at(4).at<float>();
}
const auto& lens = args.at(1).get_shape().lens();
auto batch_num = lens[0];
auto class_num = lens[1];
auto box_num = args.at(0).get_shape().lens()[1];
std::vector<std::pair<float, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements());
auto scores = make_view<float>(args.at(1).get_shape(), args.at(1).cast<float>());
const float* boxes = args.at(0).cast<float>();
shape comp_s{shape::float_type, {batch_num, class_num}};
shape_for_each(comp_s, [&](auto idx) {
auto bidx = idx[0];
auto cidx = idx[1];
std::size_t score_offset = (bidx * class_num + cidx) * box_num;
const float* batch_boxes = boxes + bidx * box_num * 4;
std::priority_queue<std::pair<float, int64_t>> sorted_boxes;
auto insert_to_sorted_boxes =
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0;
transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(!sorted_boxes.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
const std::pair<float, int64_t>& next_top_score = sorted_boxes.top();
// Check with existing selected boxes for this class, suppress if exceed the IOU
// (Intersection Over Union) threshold
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(batch_box(batch_boxes, next_top_score.second),
batch_box(batch_boxes, selected_index.second),
iou_threshold);
});
if(not not_selected)
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(bidx);
selected_indices.push_back(cidx);
selected_indices.push_back(next_top_score.second);
}
sorted_boxes.pop();
std::size_t max_output_boxes_per_class = (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
if(max_output_boxes_per_class == 0) { return result; }
float iou_threshold = (args.size() > 3) ? (args.at(3).at<float>()) : 0.0f;
float score_threshold = (args.size() > 4) ? (args.at(4).at<float>()) : 0.0f;
result.visit([&](auto output){
visit_all(args[0], args[1])([&](auto boxes, auto scores){
std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0];
const auto num_classes = lens[1];
const auto num_boxes = boxes.get_shape().lens()[1];
// boxes of a class with NMS applied [score, index]
std::vector<std::pair<float, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements());
// iterate over batches and classes
shape comp_s{shape::float_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) {
auto batch_idx = idx[0];
auto class_idx = idx[1];
// index offset for this class
std::size_t score_offset_ind = (batch_idx * num_classes + class_idx) * num_boxes;
// index to first value of this batch
std::size_t batch_boxes_ind = batch_idx * num_boxes * 4;
std::priority_queue<std::pair<float, int64_t>> boxes_heap;
auto insert_to_boxes_heap = make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
int64_t box_idx = 0;
// filter boxes below score_threshold
transform_if(
scores.begin() + score_offset_ind,
scores.begin() + score_offset_ind + num_boxes,
insert_to_boxes_heap,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
const std::pair<float, int64_t>& next_top_score = boxes_heap.top();
// Check with existing selected boxes for this class, remove box if it exceeds the IOU
// (Intersection Over Union) threshold
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(
batch_box(boxes, batch_boxes_ind, next_top_score.second),
batch_box(boxes, batch_boxes_ind, selected_index.second),
iou_threshold);
});
if(not not_selected)
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
}
boxes_heap.pop();
}
});
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
}
});
result.visit([&](auto out) {
std::copy(selected_indices.begin(), selected_indices.end(), out.begin());
);
});
return result;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment