Commit 8e622e9f authored by Juha Reunanen, committed by Davis E. King

Ignore truth rects that overlap too much (have same index in feature coordinates) (#1896)

* Add test case that makes MMOD loss go negative with certain ignore-rect configuration

* Disregard duplicate truth boxes

* Minor optimization

* Remove possibly outdated comment

* Clarify the detection count test criterion a little

* Review fix

* Review fixes:
- for perf reasons, keep only the first rect for each truth idx
- fix warning message grammar
parent cff605ad
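For context on the commit title: each truth rectangle is assigned to a single cell of the network's output feature map, addressed by the flattened index computed in loss.h below, (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x(). Two truth boxes that sit closer together than the network's downsampling factor can therefore land on the same index, and before this change the second box corrupted the loss bookkeeping. The following standalone sketch is illustrative only — it is not dlib code, and the 8x downsampling factor, the 10x10 feature map, and the box centers are made-up values — it just shows how two nearby centers collapse to one flattened index.

#include <cstddef>
#include <iostream>

int main()
{
    // Hypothetical geometry: one output channel (k = 0), a 10x10 feature map,
    // and an 8x downsampling factor from image to feature coordinates.
    const long k = 0, nr = 10, nc = 10, downsample = 8;

    // Image-space centers of two truth rectangles that are only a few pixels apart.
    const long cx1 = 33, cy1 = 41;
    const long cx2 = 38, cy2 = 44;

    const auto to_feat_idx = [&](long cx, long cy)
    {
        const long px = cx / downsample;    // feature-space column
        const long py = cy / downsample;    // feature-space row
        return (k*nr + py)*nc + px;         // same flattening as in the diff below
    };

    // Both rectangles map to index 54, i.e. the same feature-map cell, which is
    // exactly the collision the updated loss now detects and ignores.
    std::cout << to_feat_idx(cx1, cy1) << " " << to_feat_idx(cx2, cy2) << std::endl;
    return 0;
}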
dlib/dnn/loss.h
@@ -13,6 +13,7 @@
 #include "../svm/ranking_tools.h"
 #include <sstream>
 #include <map>
+#include <unordered_map>
 
 namespace dlib
 {
@@ -1142,6 +1143,8 @@ namespace dlib
             std::vector<size_t> truth_idxs;
             truth_idxs.reserve(truth->size());
 
+            std::unordered_map<size_t, rectangle> idx_to_truth_rect;
+
             // The loss will measure the number of incorrect detections. A detection is
             // incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection
             // on a truth rectangle.
@@ -1156,20 +1159,33 @@ namespace dlib
                     {
                         // Ignore boxes that can't be detected by the CNN.
                         loss -= options.loss_per_missed_target;
-                        truth_idxs.push_back(0);
+                        truth_idxs.push_back(-1);
                         continue;
                     }
 
                     const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x();
+
+                    const auto i = idx_to_truth_rect.find(idx);
+                    if (i != idx_to_truth_rect.end())
+                    {
+                        // Ignore duplicate truth box in feature coordinates.
+                        std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << x.rect;
+                        std::cout << ", and we are ignoring it because it maps to the exact same feature coordinates ";
+                        std::cout << "as another truth rectangle located at " << i->second << "." << std::endl;
+                        loss -= options.loss_per_missed_target;
+                        truth_idxs.push_back(-1);
+                        continue;
+                    }
                     loss -= out_data[idx];
                     // compute gradient
                     g[idx] = -scale;
                     truth_idxs.push_back(idx);
+                    idx_to_truth_rect[idx] = x.rect;
                 }
                 else
                 {
                     // This box was ignored so shouldn't have been counted in the loss.
                     loss -= options.loss_per_missed_target;
-                    truth_idxs.push_back(0);
+                    truth_idxs.push_back(-1);
                 }
             }
@@ -1226,6 +1242,8 @@ namespace dlib
                     if (options.overlaps_nms(best_matching_truth_box, (*truth)[i]))
                     {
                         const size_t idx = truth_idxs[i];
+                        if (idx != -1)
+                        {
                         // We are ignoring this box so we shouldn't have counted it in the
                         // loss in the first place. So we subtract out the loss values we
                         // added for it in the code above.
@@ -1239,6 +1257,7 @@ namespace dlib
+                        }
                     }
                 }
             }
 
             hit_truth_table.assign(hit_truth_table.size(), false);
             final_dets.clear();
dlib/test/dnn.cpp
@@ -3231,6 +3231,87 @@ namespace
     }
 
 // ----------------------------------------------------------------------------------------
 
+    void test_loss_mmod()
+    {
+        print_spinner();
+
+        // Define input image size.
+        constexpr int nc = 20;
+        constexpr int nr = 20;
+        constexpr int margin = 3;
+
+        // Create a checkerboard pattern.
+        std::deque<point> labeled_points;
+        for (int y = margin; y < nr - margin; ++y)
+            for (int x = margin + 1 - y % 2; x < nc - margin; x += 2)
+                labeled_points.emplace_back(x, y);
+
+        // Create training data that follows the generated pattern.
+        typedef matrix<float> input_image_type;
+
+        const auto generate_input_image = [&labeled_points, nr, nc]()
+        {
+            input_image_type sample(nr, nc);
+            sample = -1.0;
+            for (const auto& point : labeled_points)
+                sample(point.y(), point.x()) = 1.0;
+            return sample;
+        };
+
+        const auto generate_labels = [&labeled_points]()
+        {
+            const auto point_to_rect = [](const point& point) {
+                constexpr int rect_size = 5;
+                return centered_rect(
+                    point.x(), point.y(),
+                    rect_size, rect_size
+                );
+            };
+
+            std::vector<mmod_rect> labels;
+            std::transform(
+                labeled_points.begin(),
+                labeled_points.end(),
+                std::back_inserter(labels),
+                point_to_rect
+            );
+            return labels;
+        };
+
+        const input_image_type input_image = generate_input_image();
+        const std::vector<mmod_rect> labels = generate_labels();
+
+        mmod_options options(use_image_pyramid::no, { labels });
+
+        // Define a simple network.
+        using net_type = loss_mmod<con<1,5,5,1,1,con<1,5,5,2,2,input<input_image_type>>>>;
+        net_type net(options);
+        dnn_trainer<net_type> trainer(net, sgd(0.1));
+
+        // Train the network. The loss is not supposed to go negative.
+        for (int i = 0; i < 100; ++i) {
+            print_spinner();
+            trainer.train_one_step({ input_image }, { labels });
+            DLIB_TEST(trainer.get_average_loss() >= 0.0);
+        }
+
+        // Inference should return something for the training data.
+        const auto dets = net(input_image);
+        DLIB_TEST(dets.size() > 0);
+
+        // Indeed many truth objects should be found.
+        const auto approximate_desired_det_count = (nr - 2 * margin) * (nc - 2 * margin) / 2.0;
+        DLIB_TEST(dets.size() > approximate_desired_det_count * 0.45);
+        DLIB_TEST(dets.size() < approximate_desired_det_count * 1.05);
+    }
+
+// ----------------------------------------------------------------------------------------
+
     class dnn_tester : public tester
@@ -3321,6 +3402,7 @@ namespace
             test_serialization();
             test_loss_dot();
             test_loss_multimulticlass_log();
+            test_loss_mmod();
         }
 
         void perform_test()
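As a usage note, under the assumption that the standard dlib test harness is in play (dnn_tester registers itself under the switch name test_dnn), the new case can be exercised by building the dlib/test project and running the resulting dtest binary with the --test_dnn option, which runs perform_test() and therefore the new test_loss_mmod() checkerboard test above.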