".github/git@developer.sourcefind.cn:change/sglang.git" did not exist on "ad1ae7f7cd0e735599eae7ca58266b17eb38b385"
Commit 8e622e9f authored by Juha Reunanen, committed by Davis E. King
Browse files

Ignore truth rects that overlap too much (have same index in feature coordinates) (#1896)

* Add test case that makes MMOD loss go negative with certain ignore-rect configuration

* Disregard duplicate truth boxes

* Minor optimization

* Remove possibly outdated comment

* Clarify the detection count test criterion a little

* Review fix

* Review fixes:
- for perf reasons, keep only the first rect for each truth idx
- fix warning message grammar
parent cff605ad
......@@ -13,6 +13,7 @@
#include "../svm/ranking_tools.h"
#include <sstream>
#include <map>
#include <unordered_map>
namespace dlib
{
......@@ -1142,6 +1143,8 @@ namespace dlib
std::vector<size_t> truth_idxs;
truth_idxs.reserve(truth->size());
std::unordered_map<size_t, rectangle> idx_to_truth_rect;
// The loss will measure the number of incorrect detections. A detection is
// incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection
// on a truth rectangle.
......@@ -1156,20 +1159,33 @@ namespace dlib
{
// Ignore boxes that can't be detected by the CNN.
loss -= options.loss_per_missed_target;
truth_idxs.push_back(0);
truth_idxs.push_back(-1);
continue;
}
const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x();
const auto i = idx_to_truth_rect.find(idx);
if (i != idx_to_truth_rect.end())
{
// Ignore duplicate truth box in feature coordinates.
std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << x.rect;
std::cout << ", and we are ignoring it because it maps to the exact same feature coordinates ";
std::cout << "as another truth rectangle located at " << i->second << "." << std::endl;
loss -= options.loss_per_missed_target;
truth_idxs.push_back(-1);
continue;
}
loss -= out_data[idx];
// compute gradient
g[idx] = -scale;
truth_idxs.push_back(idx);
idx_to_truth_rect[idx] = x.rect;
}
else
{
// This box was ignored so shouldn't have been counted in the loss.
loss -= options.loss_per_missed_target;
truth_idxs.push_back(0);
truth_idxs.push_back(-1);
}
}
......@@ -1226,16 +1242,19 @@ namespace dlib
if (options.overlaps_nms(best_matching_truth_box, (*truth)[i]))
{
const size_t idx = truth_idxs[i];
// We are ignoring this box so we shouldn't have counted it in the
// loss in the first place. So we subtract out the loss values we
// added for it in the code above.
loss -= options.loss_per_missed_target-out_data[idx];
g[idx] = 0;
std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
std::cout << " that is suppressed by non-max-suppression ";
std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box
<< " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<", Percent covered:"
<< box_percent_covered(best_matching_truth_box,(*truth)[i]) << ")." << std::endl;
if (idx != -1)
{
// We are ignoring this box so we shouldn't have counted it in the
// loss in the first place. So we subtract out the loss values we
// added for it in the code above.
loss -= options.loss_per_missed_target-out_data[idx];
g[idx] = 0;
std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
std::cout << " that is suppressed by non-max-suppression ";
std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box
<< " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<", Percent covered:"
<< box_percent_covered(best_matching_truth_box,(*truth)[i]) << ")." << std::endl;
}
}
}
}
......
......@@ -3231,6 +3231,87 @@ namespace
}
// ----------------------------------------------------------------------------------------
void test_loss_mmod()
{
print_spinner();
// Define input image size.
constexpr int nc = 20;
constexpr int nr = 20;
constexpr int margin = 3;
// Create a checkerboard pattern.
std::deque<point> labeled_points;
for (int y = margin; y < nr - margin; ++y)
for (int x = margin + 1 - y % 2; x < nc - margin; x += 2)
labeled_points.emplace_back(x, y);
// Create training data that follows the generated pattern.
typedef matrix<float> input_image_type;
const auto generate_input_image = [&labeled_points, nr, nc]()
{
input_image_type sample(nr, nc);
sample = -1.0;
for (const auto& point : labeled_points)
sample(point.y(), point.x()) = 1.0;
return sample;
};
const auto generate_labels = [&labeled_points]()
{
const auto point_to_rect = [](const point& point) {
constexpr int rect_size = 5;
return centered_rect(
point.x(), point.y(),
rect_size, rect_size
);
};
std::vector<mmod_rect> labels;
std::transform(
labeled_points.begin(),
labeled_points.end(),
std::back_inserter(labels),
point_to_rect
);
return labels;
};
const input_image_type input_image = generate_input_image();
const std::vector<mmod_rect> labels = generate_labels();
mmod_options options(use_image_pyramid::no, { labels });
// Define a simple network.
using net_type = loss_mmod<con<1,5,5,1,1,con<1,5,5,2,2,input<input_image_type>>>>;
net_type net(options);
dnn_trainer<net_type> trainer(net, sgd(0.1));
// Train the network. The loss is not supposed to go negative.
for (int i = 0; i < 100; ++i) {
print_spinner();
trainer.train_one_step({ input_image }, { labels });
DLIB_TEST(trainer.get_average_loss() >= 0.0);
}
// Inference should return something for the training data.
const auto dets = net(input_image);
DLIB_TEST(dets.size() > 0);
// Indeed many truth objects should be found.
const auto approximate_desired_det_count = (nr - 2 * margin) * (nc - 2 * margin) / 2.0;
DLIB_TEST(dets.size() > approximate_desired_det_count * 0.45);
DLIB_TEST(dets.size() < approximate_desired_det_count * 1.05);
}
// ----------------------------------------------------------------------------------------
class dnn_tester : public tester
......@@ -3321,6 +3402,7 @@ namespace
test_serialization();
test_loss_dot();
test_loss_multimulticlass_log();
test_loss_mmod();
}
void perform_test()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment