Unverified Commit 2ed6ae1b authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Eliminate grid sensitivity in YOLO (#2488)

parent 3da3e811
......@@ -3625,25 +3625,26 @@ namespace dlib
for (size_t a = 0; a < anchors.size(); ++a)
{
const long k = a * num_feats;
for (long r = 0; r < output_tensor.nr(); ++r)
{
for (long c = 0; c < output_tensor.nc(); ++c)
{
const float obj = out_data[tensor_index(output_tensor, n, a * num_feats + 4, r, c)];
const float obj = out_data[tensor_index(output_tensor, n, k + 4, r, c)];
if (obj > adjust_threshold)
{
const auto x = out_data[tensor_index(output_tensor, n, a * num_feats + 0, r, c)];
const auto y = out_data[tensor_index(output_tensor, n, a * num_feats + 1, r, c)];
const auto w = out_data[tensor_index(output_tensor, n, a * num_feats + 2, r, c)];
const auto h = out_data[tensor_index(output_tensor, n, a * num_feats + 3, r, c)];
const auto x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
const auto y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
const auto w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
const auto h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
yolo_rect det(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
w / (1 - w) * anchors[a].width,
h / (1 - h) * anchors[a].height));
for (long k = 0; k < num_classes; ++k)
for (long i = 0; i < num_classes; ++i)
{
const float conf = obj * out_data[tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c)];
const float conf = obj * out_data[tensor_index(output_tensor, n, k + 5 + i, r, c)];
if (conf > adjust_threshold)
det.labels.emplace_back(conf, options.labels[k]);
det.labels.emplace_back(conf, options.labels[i]);
}
if (!det.labels.empty())
{
......@@ -3692,18 +3693,16 @@ namespace dlib
{
for (size_t a = 0; a < anchors.size(); ++a)
{
const auto x_idx = tensor_index(output_tensor, n, a * num_feats + 0, r, c);
const auto y_idx = tensor_index(output_tensor, n, a * num_feats + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, a * num_feats + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, a * num_feats + 3, r, c);
const auto o_idx = tensor_index(output_tensor, n, a * num_feats + 4, r, c);
const long k = a * num_feats;
const auto x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
const auto y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
const auto w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
const auto h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
// The prediction at r, c for anchor a
const yolo_rect pred(centered_drect(
dpoint((out_data[x_idx] + c) * stride_x, (out_data[y_idx] + r) * stride_y),
out_data[w_idx] / (1 - out_data[w_idx]) * anchors[a].width,
out_data[h_idx] / (1 - out_data[h_idx]) * anchors[a].height
));
const yolo_rect pred(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
w / (1 - w) * anchors[a].width,
h / (1 - h) * anchors[a].height));
// Find the best IoU for all ground truth boxes
double best_iou = 0;
......@@ -3715,6 +3714,7 @@ namespace dlib
}
// Incur loss for the boxes that are below a certain IoU threshold with any truth box
const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
if (best_iou < options.iou_ignore_threshold)
g[o_idx] = options.lambda_obj * out_data[o_idx];
}
......@@ -3764,14 +3764,7 @@ namespace dlib
const long c = t_center.x() / stride_x;
const long r = t_center.y() / stride_y;
const auto x_idx = tensor_index(output_tensor, n, a * num_feats + 0, r, c);
const auto y_idx = tensor_index(output_tensor, n, a * num_feats + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, a * num_feats + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, a * num_feats + 3, r, c);
const auto o_idx = tensor_index(output_tensor, n, a * num_feats + 4, r, c);
// This grid cell should detect an object
g[o_idx] = options.lambda_obj * (out_data[o_idx] - 1);
const long k = a * num_feats;
// Get the truth box target values
const double tx = t_center.x() / stride_x - c;
......@@ -3783,16 +3776,24 @@ namespace dlib
const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_rect.area());
// Compute the gradient for the box coordinates
g[x_idx] = scale_box * (out_data[x_idx] - tx);
g[y_idx] = scale_box * (out_data[y_idx] - ty);
const auto x_idx = tensor_index(output_tensor, n, k + 0, r, c);
const auto y_idx = tensor_index(output_tensor, n, k + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, k + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, k + 3, r, c);
g[x_idx] = scale_box * (out_data[x_idx] * 2.0 - 0.5 - tx);
g[y_idx] = scale_box * (out_data[y_idx] * 2.0 - 0.5 - ty);
g[w_idx] = scale_box * (out_data[w_idx] - tw);
g[h_idx] = scale_box * (out_data[h_idx] - th);
// This grid cell should detect an object
const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
g[o_idx] = options.lambda_obj * (out_data[o_idx] - 1);
// Compute the classification error
for (long k = 0; k < num_classes; ++k)
for (long i = 0; i < num_classes; ++i)
{
const auto c_idx = tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c);
if (truth_box.label == options.labels[k])
const auto c_idx = tensor_index(output_tensor, n, k + 5 + i, r, c);
if (truth_box.label == options.labels[i])
g[c_idx] = options.lambda_cls * (out_data[c_idx] - 1);
else
g[c_idx] = options.lambda_cls * out_data[c_idx];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment