Commit 5ec4b513 authored by Ted Themistokleous
Browse files

Use copy_if and parallel execution by leveraging TBB.

- Add support for TBB in MIGraphX
- Install the TBB development package (libtbb-dev) in the Dockerfile
- Replace inner loop with copy_if and use std::execution::par to filter
- Change heap to vector and sort in parallel in filter_boxes_by_score()

With the help of Paul, this cuts down NMS in ref from around 43-44s to about 2s.
parent 422d2c73
...@@ -50,6 +50,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -50,6 +50,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
hipify-clang \ hipify-clang \
half \ half \
libssl-dev \ libssl-dev \
libtbb-dev \
zlib1g-dev && \ zlib1g-dev && \
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
......
...@@ -247,6 +247,9 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU ...@@ -247,6 +247,9 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
find_package(Threads) find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads) target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(TBB REQUIRED)
target_link_libraries(migraphx PUBLIC TBB::tbb)
find_package(nlohmann_json 3.8.0 REQUIRED) find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json) target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <queue> #include <queue>
#include <cstdint> #include <cstdint>
#include <iterator> #include <iterator>
#include <execution>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
...@@ -225,12 +226,12 @@ struct nonmaxsuppression ...@@ -225,12 +226,12 @@ struct nonmaxsuppression
// filter boxes below score_threshold // filter boxes below score_threshold
template <class T> template <class T>
std::priority_queue<std::pair<double, int64_t>> std::vector<std::pair<double, int64_t>>
filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const
{ {
std::priority_queue<std::pair<double, int64_t>> boxes_heap; std::vector<std::pair<double, int64_t>> boxes_heap;
auto insert_to_boxes_heap = auto insert_to_boxes_heap =
make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); }); make_function_output_iterator([&](const auto& x) { boxes_heap.push_back(x); });
int64_t box_idx = 0; int64_t box_idx = 0;
if(score_threshold > 0.0) if(score_threshold > 0.0)
...@@ -253,6 +254,7 @@ struct nonmaxsuppression ...@@ -253,6 +254,7 @@ struct nonmaxsuppression
return std::make_pair(sc, box_idx - 1); return std::make_pair(sc, box_idx - 1);
}); });
} }
std::sort(std::execution::par, boxes_heap.begin(), boxes_heap.end());
return boxes_heap; return boxes_heap;
} }
...@@ -287,28 +289,32 @@ struct nonmaxsuppression ...@@ -287,28 +289,32 @@ struct nonmaxsuppression
{ {
// select next top scorer box and remove any boxes from boxes_heap that exceeds IOU // select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
// threshold with the selected box // threshold with the selected box
const auto next_top_score = boxes_heap.top(); const auto next_top_score = boxes_heap.front();
auto next_box = batch_box(batch_boxes_start, next_top_score.second); auto next_box = batch_box(batch_boxes_start, next_top_score.second);
auto next_box_idx = next_top_score.second; auto next_box_idx = next_top_score.second;
boxes_heap.pop(); // Poor man's "pop" for vector
boxes_heap.erase(boxes_heap.begin());
selected_boxes_inside_class++; selected_boxes_inside_class++;
selected_indices.push_back(batch_idx); selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx); selected_indices.push_back(class_idx);
selected_indices.push_back(next_box_idx); selected_indices.push_back(next_box_idx);
std::priority_queue<std::pair<double, int64_t>> remainder_boxes;
while(not boxes_heap.empty()) std::vector<std::pair<double, int64_t>> remainder_boxes(boxes_heap.size());
{
auto iou_candidate_box = boxes_heap.top(); auto it =
auto iou_candidate = batch_box(batch_boxes_start, iou_candidate_box.second); std::copy_if(std::execution::par,
auto suppress_box = boxes_heap.begin(),
this->suppress_by_iou(iou_candidate, next_box, iou_threshold); boxes_heap.end(),
if(not suppress_box) remainder_boxes.begin(),
{ [&](auto iou_candidate_box) {
remainder_boxes.push(iou_candidate_box); auto iou_box =
} batch_box(batch_boxes_start, iou_candidate_box.second);
boxes_heap.pop(); return not this->suppress_by_iou(
} std::ref(iou_box), std::ref(next_box), iou_threshold);
});
remainder_boxes.resize(it - remainder_boxes.begin());
boxes_heap = remainder_boxes; boxes_heap = remainder_boxes;
} }
}); });
...@@ -337,7 +343,6 @@ struct nonmaxsuppression ...@@ -337,7 +343,6 @@ struct nonmaxsuppression
num_selected = compute_nms(output, num_selected = compute_nms(output,
boxes, boxes,
scores, scores,
max_output_shape,
max_output_boxes_per_class, max_output_boxes_per_class,
iou_threshold, iou_threshold,
score_threshold); score_threshold);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment