Commit c9edeb6d authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

backup debug stuff from nms

parent 838ce88f
...@@ -197,6 +197,7 @@ struct nonmaxsuppression ...@@ -197,6 +197,7 @@ struct nonmaxsuppression
return not std::is_sorted(bx.begin(), bx.end()); return not std::is_sorted(bx.begin(), bx.end());
})) }))
{ {
// std::cout << "False-NoOverlap" << std::endl;
return false; return false;
} }
...@@ -207,11 +208,15 @@ struct nonmaxsuppression ...@@ -207,11 +208,15 @@ struct nonmaxsuppression
if(area1 <= .0f or area2 <= .0f or union_area <= .0f) if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
{ {
// std::cout << "False" << std::endl;
// std::cout << "A1=" << area1 << " A2=" << area2 << " UnionArea=" << union_area <<
// std::endl;
return false; return false;
} }
const double intersection_over_union = intersection_area / union_area; const double intersection_over_union = intersection_area / union_area;
// std::cout << "Checking IOU > IOU_THRESH" << (intersection_over_union > iou_threshold) <<
// std::endl;
return intersection_over_union > iou_threshold; return intersection_over_union > iou_threshold;
} }
...@@ -247,9 +252,17 @@ struct nonmaxsuppression ...@@ -247,9 +252,17 @@ struct nonmaxsuppression
{ {
std::fill(output.begin(), output.end(), 0); std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens(); const auto& lens = scores.get_shape().lens();
const auto& box_lens = boxes.get_shape().lens();
const auto num_batches = lens[0]; const auto num_batches = lens[0];
const auto num_classes = lens[1]; const auto num_classes = lens[1];
const auto num_boxes = lens[2]; const auto num_boxes = lens[2];
std::cout << "num_batches=" << num_batches << std::endl;
std::cout << "num_classes=" << num_classes << std::endl;
std::cout << "num_boxes=" << num_boxes << std::endl;
std::cout << "box dims" << box_lens[0] << " " << box_lens[1] << " " << box_lens[2]
<< std::endl;
// boxes of a class with NMS applied [score, index] // boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class; std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices; std::vector<int64_t> selected_indices;
...@@ -259,6 +272,8 @@ struct nonmaxsuppression ...@@ -259,6 +272,8 @@ struct nonmaxsuppression
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](auto idx) {
auto batch_idx = idx[0]; auto batch_idx = idx[0];
auto class_idx = idx[1]; auto class_idx = idx[1];
std::cout << "Batch_idx=" << batch_idx << std::endl;
std::cout << "class_idx=" << class_idx << std::endl;
// index offset for this class // index offset for this class
auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes; auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
// iterator to first value of this batch // iterator to first value of this batch
...@@ -266,12 +281,16 @@ struct nonmaxsuppression ...@@ -266,12 +281,16 @@ struct nonmaxsuppression
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold); auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
selected_boxes_inside_class.clear(); selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold // Get the next box with top score, filter by iou_threshold
while(not boxes_heap.empty() && while(not boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class) selected_boxes_inside_class.size() < max_output_boxes_per_class)
{ {
// std::cout << "heap size=" << boxes_heap.size() << std::endl;
// Check with existing selected boxes for this class, remove box if it // Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold // exceeds the IOU (Intersection Over Union) threshold
const auto next_top_score = boxes_heap.top(); const auto next_top_score = boxes_heap.top();
// std::cout << "top_score=" << next_top_score.second << std::endl;
bool not_selected = bool not_selected =
std::any_of(selected_boxes_inside_class.begin(), std::any_of(selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(), selected_boxes_inside_class.end(),
...@@ -288,10 +307,14 @@ struct nonmaxsuppression ...@@ -288,10 +307,14 @@ struct nonmaxsuppression
selected_indices.push_back(batch_idx); selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx); selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second); selected_indices.push_back(next_top_score.second);
std::cout << "Not selected seleted_boxxes_inside_size="
<< selected_boxes_inside_class.size() << std::endl;
} }
boxes_heap.pop(); boxes_heap.pop();
// std::cout << "Pop heap size=" << boxes_heap.size() << std::endl;
} }
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
return selected_indices.size() / 3; return selected_indices.size() / 3;
} }
...@@ -302,16 +325,36 @@ struct nonmaxsuppression ...@@ -302,16 +325,36 @@ struct nonmaxsuppression
shape max_output_shape = {output_shape.type(), output_shape.max_lens()}; shape max_output_shape = {output_shape.type(), output_shape.max_lens()};
argument result{max_output_shape}; argument result{max_output_shape};
std::cout << args.size() << std::endl;
std::size_t max_output_boxes_per_class = std::size_t max_output_boxes_per_class =
(args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0; (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
std::cout << "Max output boxes per class=" << max_output_boxes_per_class << std::endl;
if(max_output_boxes_per_class == 0) if(max_output_boxes_per_class == 0)
{ {
return result; return result;
} }
//#bound the max amount of boxes allowed to the max of boxes
if(max_output_boxes_per_class > args[1].get_shape().lens()[2])
{
max_output_boxes_per_class = args[1].get_shape().lens()[2];
std::cout << "update Max output boxes per class=" << max_output_boxes_per_class
<< std::endl;
}
std::cout << "Args@2=" << args.at(2) << std::endl;
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f; double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f; double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
std::size_t num_selected = 0; std::size_t num_selected = 0;
// std::cout << "boxes=" << args[0].elements() << " scores=" << args[1].elements() <<
// std::endl;
std::cout << "IOU_TH=" << iou_threshold << " SCORE_TH=" << score_threshold << std::endl;
result.visit([&](auto output) { result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) { visit_all(args[0], args[1])([&](auto boxes, auto scores) {
num_selected = compute_nms(output, num_selected = compute_nms(output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment