Commit 2c1cdd15 authored by charlie's avatar charlie
Browse files

Use iterators and fix basic_iota_iterator

parent 80f7c2b7
...@@ -81,8 +81,9 @@ struct basic_iota_iterator ...@@ -81,8 +81,9 @@ struct basic_iota_iterator
index--; index--;
return it; return it;
} }
// TODO: operator->
reference operator*() const { return f(index); } reference operator*() const { return f(index); }
pointer operator->() const { return &f(index); }
reference operator[](int n) const { return f(index + n); }
}; };
template <class T, class F> template <class T, class F>
......
...@@ -81,8 +81,8 @@ struct nonmaxsuppression ...@@ -81,8 +81,8 @@ struct nonmaxsuppression
struct box struct box
{ {
std::array<float, 2> x; std::array<double, 2> x;
std::array<float, 2> y; std::array<double, 2> y;
void sort() void sort()
{ {
...@@ -90,9 +90,9 @@ struct nonmaxsuppression ...@@ -90,9 +90,9 @@ struct nonmaxsuppression
std::sort(y.begin(), y.end()); std::sort(y.begin(), y.end());
} }
std::array<float, 2>& operator[](std::size_t i) { return i == 0 ? x : y; } std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
float area() const double area() const
{ {
assert(std::is_sorted(x.begin(), x.end())); assert(std::is_sorted(x.begin(), x.end()));
assert(std::is_sorted(y.begin(), y.end())); assert(std::is_sorted(y.begin(), y.end()));
...@@ -101,29 +101,29 @@ struct nonmaxsuppression ...@@ -101,29 +101,29 @@ struct nonmaxsuppression
}; };
template <class T> template <class T>
box batch_box(const T& boxes, std::size_t box_ind, std::size_t box_idx) const box batch_box(T boxes, std::size_t box_idx) const
{ {
box result{}; box result{};
auto start = box_ind + 4 * box_idx; auto start = boxes + 4 * box_idx;
if(center_point_box) if(center_point_box)
{ {
float half_width = boxes[start + 2] / 2.0; double half_width = start[2] / 2.0;
float half_height = boxes[start + 3] / 2.0; double half_height = start[3] / 2.0;
float x_center = boxes[start + 0]; double x_center = start[0];
float y_center = boxes[start + 1]; double y_center = start[1];
result.x = {x_center - half_width, x_center + half_width}; result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height}; result.y = {y_center - half_height, y_center + half_height};
} }
else else
{ {
result.x = {static_cast<float>(boxes[start + 1]), static_cast<float>(boxes[start + 3])}; result.x = {static_cast<double>(start[1]), static_cast<double>(start[3])};
result.y = {static_cast<float>(boxes[start + 0]), static_cast<float>(boxes[start + 2])}; result.y = {static_cast<double>(start[0]), static_cast<double>(start[2])};
} }
return result; return result;
} }
inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const inline bool suppress_by_iou(box b1, box b2, double iou_threshold) const
{ {
b1.sort(); b1.sort();
b2.sort(); b2.sort();
...@@ -135,7 +135,7 @@ struct nonmaxsuppression ...@@ -135,7 +135,7 @@ struct nonmaxsuppression
intersection[i][1] = std::min(b1[i][1], b2[i][1]); intersection[i][1] = std::min(b1[i][1], b2[i][1]);
} }
std::vector<std::array<float, 2>> bbox = {intersection.x, intersection.y}; std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) { if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
return not std::is_sorted(bx.begin(), bx.end()); return not std::is_sorted(bx.begin(), bx.end());
})) }))
...@@ -143,33 +143,33 @@ struct nonmaxsuppression ...@@ -143,33 +143,33 @@ struct nonmaxsuppression
return false; return false;
} }
const float area1 = b1.area(); const double area1 = b1.area();
const float area2 = b2.area(); const double area2 = b2.area();
const float intersection_area = intersection.area(); const double intersection_area = intersection.area();
const float union_area = area1 + area2 - intersection_area; const double union_area = area1 + area2 - intersection_area;
if(area1 <= .0f or area2 <= .0f or union_area <= .0f) if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
{ {
return false; return false;
} }
const float intersection_over_union = intersection_area / union_area; const double intersection_over_union = intersection_area / union_area;
return intersection_over_union > iou_threshold; return intersection_over_union > iou_threshold;
} }
// filter boxes below score_threshold // filter boxes below score_threshold
template <class T> template <class T>
std::priority_queue<std::pair<float, int64_t>> filter_boxes_by_score( std::priority_queue<std::pair<double, int64_t>> filter_boxes_by_score(
T scores, std::size_t score_offset_ind, std::size_t num_boxes, float score_threshold) const T scores_start, std::size_t num_boxes, double score_threshold) const
{ {
std::priority_queue<std::pair<float, int64_t>> boxes_heap; std::priority_queue<std::pair<double, int64_t>> boxes_heap;
auto insert_to_boxes_heap = auto insert_to_boxes_heap =
make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); }); make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
int64_t box_idx = 0; int64_t box_idx = 0;
transform_if( transform_if(
scores.begin() + score_offset_ind, scores_start,
scores.begin() + score_offset_ind + num_boxes, scores_start + num_boxes,
insert_to_boxes_heap, insert_to_boxes_heap,
[&](auto sc) { [&](auto sc) {
box_idx++; box_idx++;
...@@ -178,7 +178,48 @@ struct nonmaxsuppression ...@@ -178,7 +178,48 @@ struct nonmaxsuppression
[&](auto sc) { return std::make_pair(sc, box_idx - 1); }); [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
return boxes_heap; return boxes_heap;
} }
template <class H, class S>
void select_boxes(
H& boxes_heap,
std::vector<std::pair<double, int64_t>>& selected_boxes_inside_class,
std::vector<int64_t>& selected_indices,
S batch_boxes_start,
std::size_t max_output_boxes_per_class,
double iou_threshold,
std::size_t batch_idx,
std::size_t class_idx
) const
{
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
const auto next_top_score = boxes_heap.top();
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(
batch_box(batch_boxes_start, next_top_score.second),
batch_box(batch_boxes_start, selected_index.second),
iou_threshold);
});
if(not not_selected)
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
}
boxes_heap.pop();
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -189,61 +230,34 @@ struct nonmaxsuppression ...@@ -189,61 +230,34 @@ struct nonmaxsuppression
{ {
return result; return result;
} }
float iou_threshold = (args.size() > 3) ? (args.at(3).at<float>()) : 0.0f; double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
float score_threshold = (args.size() > 4) ? (args.at(4).at<float>()) : 0.0f; double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
result.visit([&](auto output) { result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) { visit_all(args[0], args[1])([&](auto boxes, auto scores) {
std::fill(output.begin(), output.end(), 0); std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens(); const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0]; const auto num_batches = lens[0];
const auto num_classes = lens[1]; const auto num_classes = lens[1];
const auto num_boxes = lens[2]; const auto num_boxes = lens[2];
// boxes of a class with NMS applied [score, index] // boxes of a class with NMS applied [score, index]
std::vector<std::pair<float, int64_t>> selected_boxes_inside_class; std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices; std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements()); selected_boxes_inside_class.reserve(output_shape.elements());
// iterate over batches and classes // iterate over batches and classes
shape comp_s{shape::float_type, {num_batches, num_classes}}; shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](auto idx) {
auto batch_idx = idx[0]; auto batch_idx = idx[0];
auto class_idx = idx[1]; auto class_idx = idx[1];
// index offset for this class // index offset for this class
std::size_t score_offset_ind = auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
(batch_idx * num_classes + class_idx) * num_boxes; // iterator to first value of this batch
// index to first value of this batch auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
std::size_t batch_boxes_ind = batch_idx * num_boxes * 4; auto boxes_heap =
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
filter_boxes_by_score(scores, score_offset_ind, num_boxes, score_threshold); select_boxes(boxes_heap, selected_boxes_inside_class, selected_indices, batch_boxes_start, max_output_boxes_per_class, iou_threshold, batch_idx, class_idx);
selected_boxes_inside_class.clear(); });
// Get the next box with top score, filter by iou_threshold std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
const auto next_top_score = boxes_heap.top();
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(
batch_box(boxes, batch_boxes_ind, next_top_score.second),
batch_box(boxes, batch_boxes_ind, selected_index.second),
iou_threshold);
});
if(not not_selected)
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
}
boxes_heap.pop();
}
});
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment