"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d07f73003d4d077854869b8f73275657f280334c"
Commit 8e622e9f authored by Juha Reunanen, committed by Davis E. King

Ignore truth rects that overlap too much (have same index in feature coordinates) (#1896)

* Add test case that makes MMOD loss go negative with certain ignore-rect configuration

* Disregard duplicate truth boxes

* Minor optimization

* Remove possibly outdated comment

* Clarify the detection count test criterion a little

* Review fix

* Review fixes:
- for perf reasons, keep only the first rect for each truth idx
- fix warning message grammar
parent cff605ad
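The gist of the change, for context: dlib's MMOD loss maps each truth rectangle to a cell of the network's output feature map, and when two truth rectangles land on the same cell the loss accounting charges the network twice for what is at most one achievable detection, which (as the new test demonstrates) can push the computed loss negative. The fix keeps only the first truth rectangle per feature-map index and treats the rest as ignored. Below is a minimal, self-contained sketch of that idea, not dlib's code: the stride-2 downsampling, the center-division mapping to feature coordinates, and the names Rect/feat_nc are simplifying assumptions made up for illustration.

    // Standalone sketch: keep only the first truth box per feature-map cell,
    // treating later boxes that collide in feature coordinates as ignored.
    #include <cstddef>
    #include <iostream>
    #include <unordered_map>
    #include <vector>

    struct Rect { long left, top, right, bottom; };   // stand-in for dlib::rectangle

    int main()
    {
        const long stride = 2;                        // assumed overall downsampling of the net
        const long feat_nc = 10;                      // assumed feature-map width
        const std::vector<Rect> truths = {
            {4, 4, 8, 8},                             // center (6, 6)  -> feature (3, 3)
            {5, 5, 9, 9},                             // center (7, 7)  -> feature (3, 3): collides
            {12, 4, 16, 8},                           // center (14, 6) -> feature (7, 3)
        };

        std::unordered_map<std::size_t, Rect> idx_to_truth_rect;
        for (const auto& r : truths)
        {
            const long cx = (r.left + r.right) / 2, cy = (r.top + r.bottom) / 2;
            const std::size_t idx = (cy / stride) * feat_nc + (cx / stride);  // flattened feature index
            if (!idx_to_truth_rect.emplace(idx, r).second)
                std::cout << "duplicate truth in feature coords, ignoring box at index " << idx << "\n";
        }
        std::cout << "kept " << idx_to_truth_rect.size() << " of " << truths.size() << " truth boxes\n";
    }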
@@ -13,6 +13,7 @@
 #include "../svm/ranking_tools.h"
 #include <sstream>
 #include <map>
+#include <unordered_map>
 
 namespace dlib
 {
@@ -1142,6 +1143,8 @@ namespace dlib
             std::vector<size_t> truth_idxs;
             truth_idxs.reserve(truth->size());
 
+            std::unordered_map<size_t, rectangle> idx_to_truth_rect;
+
             // The loss will measure the number of incorrect detections. A detection is
             // incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection
             // on a truth rectangle.
@@ -1156,20 +1159,33 @@ namespace dlib
                 {
                     // Ignore boxes that can't be detected by the CNN.
                     loss -= options.loss_per_missed_target;
-                    truth_idxs.push_back(0);
+                    truth_idxs.push_back(-1);
                     continue;
                 }
                 const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x();
+
+                const auto i = idx_to_truth_rect.find(idx);
+                if (i != idx_to_truth_rect.end())
+                {
+                    // Ignore duplicate truth box in feature coordinates.
+                    std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << x.rect;
+                    std::cout << ", and we are ignoring it because it maps to the exact same feature coordinates ";
+                    std::cout << "as another truth rectangle located at " << i->second << "." << std::endl;
+
+                    loss -= options.loss_per_missed_target;
+                    truth_idxs.push_back(-1);
+                    continue;
+                }
+
                 loss -= out_data[idx];
                 // compute gradient
                 g[idx] = -scale;
                 truth_idxs.push_back(idx);
+                idx_to_truth_rect[idx] = x.rect;
             }
             else
             {
                 // This box was ignored so shouldn't have been counted in the loss.
                 loss -= options.loss_per_missed_target;
-                truth_idxs.push_back(0);
+                truth_idxs.push_back(-1);
             }
         }
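A side note on the `idx` arithmetic the new map keys on: the output tensor is laid out channel-major, so element (k, y, x) sits at offset (k*output_tensor.nr() + y)*output_tensor.nc() + x in out_data, and two truth boxes whose feature coordinates coincide therefore produce the same offset. A tiny standalone sketch with assumed shapes (single detection channel, 10x10 feature map), not dlib code:

    #include <cstddef>
    #include <iostream>

    int main()
    {
        const std::size_t k = 0, nr = 10, nc = 10;   // one detection channel, 10x10 feature map
        const std::size_t y = 3, x = 3;              // feature coordinates of a truth box
        const std::size_t idx = (k*nr + y)*nc + x;   // flattened offset, same formula as the loss code
        std::cout << "idx = " << idx << "\n";        // prints 33; two truths with identical (k, y, x)
                                                     // yield the same idx, which is exactly what the
                                                     // new idx_to_truth_rect map detects
    }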
@@ -1226,16 +1242,19 @@ namespace dlib
                     if (options.overlaps_nms(best_matching_truth_box, (*truth)[i]))
                     {
                         const size_t idx = truth_idxs[i];
-                        // We are ignoring this box so we shouldn't have counted it in the
-                        // loss in the first place. So we subtract out the loss values we
-                        // added for it in the code above.
-                        loss -= options.loss_per_missed_target-out_data[idx];
-                        g[idx] = 0;
-                        std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
-                        std::cout << " that is suppressed by non-max-suppression ";
-                        std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box
-                            << " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<", Percent covered:"
-                            << box_percent_covered(best_matching_truth_box,(*truth)[i]) << ")." << std::endl;
+                        if (idx != -1)
+                        {
+                            // We are ignoring this box so we shouldn't have counted it in the
+                            // loss in the first place. So we subtract out the loss values we
+                            // added for it in the code above.
+                            loss -= options.loss_per_missed_target-out_data[idx];
+                            g[idx] = 0;
+                            std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
+                            std::cout << " that is suppressed by non-max-suppression ";
+                            std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box
+                                << " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<", Percent covered:"
+                                << box_percent_covered(best_matching_truth_box,(*truth)[i]) << ")." << std::endl;
+                        }
                     }
                 }
             }
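One detail worth spelling out: `truth_idxs` is a `std::vector<size_t>`, yet the code pushes `-1` and later tests `idx != -1`. That is well defined in C++: `-1` converts to the maximum `size_t` value on the way into the vector, and the comparison converts `-1` the same way, so the sentinel survives the round trip. A minimal standalone check:

    #include <cstddef>
    #include <limits>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> truth_idxs;
        truth_idxs.push_back(-1);                    // stored as the maximum size_t value
        const std::size_t idx = truth_idxs.back();
        static_assert(static_cast<std::size_t>(-1) == std::numeric_limits<std::size_t>::max(),
                      "-1 wraps to SIZE_MAX");
        return idx != static_cast<std::size_t>(-1);  // false for the sentinel; true for any real index
    }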
@@ -3231,6 +3231,87 @@ namespace
     }
 
+// ----------------------------------------------------------------------------------------
+
+    void test_loss_mmod()
+    {
+        print_spinner();
+
+        // Define input image size.
+        constexpr int nc = 20;
+        constexpr int nr = 20;
+        constexpr int margin = 3;
+
+        // Create a checkerboard pattern.
+        std::deque<point> labeled_points;
+        for (int y = margin; y < nr - margin; ++y)
+            for (int x = margin + 1 - y % 2; x < nc - margin; x += 2)
+                labeled_points.emplace_back(x, y);
+
+        // Create training data that follows the generated pattern.
+        typedef matrix<float> input_image_type;
+
+        const auto generate_input_image = [&labeled_points, nr, nc]()
+        {
+            input_image_type sample(nr, nc);
+            sample = -1.0;
+            for (const auto& point : labeled_points)
+                sample(point.y(), point.x()) = 1.0;
+            return sample;
+        };
+
+        const auto generate_labels = [&labeled_points]()
+        {
+            const auto point_to_rect = [](const point& point) {
+                constexpr int rect_size = 5;
+                return centered_rect(
+                    point.x(), point.y(),
+                    rect_size, rect_size
+                );
+            };
+
+            std::vector<mmod_rect> labels;
+            std::transform(
+                labeled_points.begin(),
+                labeled_points.end(),
+                std::back_inserter(labels),
+                point_to_rect
+            );
+            return labels;
+        };
+
+        const input_image_type input_image = generate_input_image();
+        const std::vector<mmod_rect> labels = generate_labels();
+
+        mmod_options options(use_image_pyramid::no, { labels });
+
+        // Define a simple network.
+        using net_type = loss_mmod<con<1,5,5,1,1,con<1,5,5,2,2,input<input_image_type>>>>;
+        net_type net(options);
+
+        dnn_trainer<net_type> trainer(net, sgd(0.1));
+
+        // Train the network. The loss is not supposed to go negative.
+        for (int i = 0; i < 100; ++i) {
+            print_spinner();
+            trainer.train_one_step({ input_image }, { labels });
+            DLIB_TEST(trainer.get_average_loss() >= 0.0);
+        }
+
+        // Inference should return something for the training data.
+        const auto dets = net(input_image);
+        DLIB_TEST(dets.size() > 0);
+
+        // Indeed many truth objects should be found.
+        const auto approximate_desired_det_count = (nr - 2 * margin) * (nc - 2 * margin) / 2.0;
+        DLIB_TEST(dets.size() > approximate_desired_det_count * 0.45);
+        DLIB_TEST(dets.size() < approximate_desired_det_count * 1.05);
+    }
+
 // ----------------------------------------------------------------------------------------
 
     class dnn_tester : public tester
@@ -3321,6 +3402,7 @@ namespace
             test_serialization();
             test_loss_dot();
             test_loss_multimulticlass_log();
+            test_loss_mmod();
         }
 
         void perform_test()
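For reference, the arithmetic behind the detection-count bounds in `test_loss_mmod`: the checkerboard labels roughly half of the interior (nr - 2*margin) by (nc - 2*margin) cells, and the test accepts anything between 45% and 105% of that count. A small worked example with the same constants (illustrative only, not part of the test code):

    #include <iostream>

    int main()
    {
        const int nr = 20, nc = 20, margin = 3;
        // Half of the 14x14 interior cells carry a labeled point.
        const double approximate_desired_det_count = (nr - 2 * margin) * (nc - 2 * margin) / 2.0;  // 14*14/2 = 98
        std::cout << "expected ~" << approximate_desired_det_count << " truth objects\n";
        std::cout << "accepted detection count range: (" << approximate_desired_det_count * 0.45
                  << ", " << approximate_desired_det_count * 1.05 << ")\n";                        // (44.1, 102.9)
    }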