Unverified commit 2ed6ae1b, authored by Adrià Arrufat and committed by GitHub
Browse files

Eliminate grid sensitivity in YOLO (#2488)

parent 3da3e811
...@@ -3625,25 +3625,26 @@ namespace dlib ...@@ -3625,25 +3625,26 @@ namespace dlib
for (size_t a = 0; a < anchors.size(); ++a) for (size_t a = 0; a < anchors.size(); ++a)
{ {
const long k = a * num_feats;
for (long r = 0; r < output_tensor.nr(); ++r) for (long r = 0; r < output_tensor.nr(); ++r)
{ {
for (long c = 0; c < output_tensor.nc(); ++c) for (long c = 0; c < output_tensor.nc(); ++c)
{ {
const float obj = out_data[tensor_index(output_tensor, n, a * num_feats + 4, r, c)]; const float obj = out_data[tensor_index(output_tensor, n, k + 4, r, c)];
if (obj > adjust_threshold) if (obj > adjust_threshold)
{ {
const auto x = out_data[tensor_index(output_tensor, n, a * num_feats + 0, r, c)]; const auto x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
const auto y = out_data[tensor_index(output_tensor, n, a * num_feats + 1, r, c)]; const auto y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
const auto w = out_data[tensor_index(output_tensor, n, a * num_feats + 2, r, c)]; const auto w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
const auto h = out_data[tensor_index(output_tensor, n, a * num_feats + 3, r, c)]; const auto h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
yolo_rect det(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y), yolo_rect det(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
w / (1 - w) * anchors[a].width, w / (1 - w) * anchors[a].width,
h / (1 - h) * anchors[a].height)); h / (1 - h) * anchors[a].height));
for (long k = 0; k < num_classes; ++k) for (long i = 0; i < num_classes; ++i)
{ {
const float conf = obj * out_data[tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c)]; const float conf = obj * out_data[tensor_index(output_tensor, n, k + 5 + i, r, c)];
if (conf > adjust_threshold) if (conf > adjust_threshold)
det.labels.emplace_back(conf, options.labels[k]); det.labels.emplace_back(conf, options.labels[i]);
} }
if (!det.labels.empty()) if (!det.labels.empty())
{ {
...@@ -3692,18 +3693,16 @@ namespace dlib ...@@ -3692,18 +3693,16 @@ namespace dlib
{ {
for (size_t a = 0; a < anchors.size(); ++a) for (size_t a = 0; a < anchors.size(); ++a)
{ {
const auto x_idx = tensor_index(output_tensor, n, a * num_feats + 0, r, c); const long k = a * num_feats;
const auto y_idx = tensor_index(output_tensor, n, a * num_feats + 1, r, c); const auto x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
const auto w_idx = tensor_index(output_tensor, n, a * num_feats + 2, r, c); const auto y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
const auto h_idx = tensor_index(output_tensor, n, a * num_feats + 3, r, c); const auto w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
const auto o_idx = tensor_index(output_tensor, n, a * num_feats + 4, r, c); const auto h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
// The prediction at r, c for anchor a // The prediction at r, c for anchor a
const yolo_rect pred(centered_drect( const yolo_rect pred(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
dpoint((out_data[x_idx] + c) * stride_x, (out_data[y_idx] + r) * stride_y), w / (1 - w) * anchors[a].width,
out_data[w_idx] / (1 - out_data[w_idx]) * anchors[a].width, h / (1 - h) * anchors[a].height));
out_data[h_idx] / (1 - out_data[h_idx]) * anchors[a].height
));
// Find the best IoU for all ground truth boxes // Find the best IoU for all ground truth boxes
double best_iou = 0; double best_iou = 0;
...@@ -3715,6 +3714,7 @@ namespace dlib ...@@ -3715,6 +3714,7 @@ namespace dlib
} }
// Incur loss for the boxes that are below a certain IoU threshold with any truth box // Incur loss for the boxes that are below a certain IoU threshold with any truth box
const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
if (best_iou < options.iou_ignore_threshold) if (best_iou < options.iou_ignore_threshold)
g[o_idx] = options.lambda_obj * out_data[o_idx]; g[o_idx] = options.lambda_obj * out_data[o_idx];
} }
...@@ -3764,14 +3764,7 @@ namespace dlib ...@@ -3764,14 +3764,7 @@ namespace dlib
const long c = t_center.x() / stride_x; const long c = t_center.x() / stride_x;
const long r = t_center.y() / stride_y; const long r = t_center.y() / stride_y;
const auto x_idx = tensor_index(output_tensor, n, a * num_feats + 0, r, c); const long k = a * num_feats;
const auto y_idx = tensor_index(output_tensor, n, a * num_feats + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, a * num_feats + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, a * num_feats + 3, r, c);
const auto o_idx = tensor_index(output_tensor, n, a * num_feats + 4, r, c);
// This grid cell should detect an object
g[o_idx] = options.lambda_obj * (out_data[o_idx] - 1);
// Get the truth box target values // Get the truth box target values
const double tx = t_center.x() / stride_x - c; const double tx = t_center.x() / stride_x - c;
...@@ -3783,16 +3776,24 @@ namespace dlib ...@@ -3783,16 +3776,24 @@ namespace dlib
const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_rect.area()); const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_rect.area());
// Compute the gradient for the box coordinates // Compute the gradient for the box coordinates
g[x_idx] = scale_box * (out_data[x_idx] - tx); const auto x_idx = tensor_index(output_tensor, n, k + 0, r, c);
g[y_idx] = scale_box * (out_data[y_idx] - ty); const auto y_idx = tensor_index(output_tensor, n, k + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, k + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, k + 3, r, c);
g[x_idx] = scale_box * (out_data[x_idx] * 2.0 - 0.5 - tx);
g[y_idx] = scale_box * (out_data[y_idx] * 2.0 - 0.5 - ty);
g[w_idx] = scale_box * (out_data[w_idx] - tw); g[w_idx] = scale_box * (out_data[w_idx] - tw);
g[h_idx] = scale_box * (out_data[h_idx] - th); g[h_idx] = scale_box * (out_data[h_idx] - th);
// This grid cell should detect an object
const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
g[o_idx] = options.lambda_obj * (out_data[o_idx] - 1);
// Compute the classification error // Compute the classification error
for (long k = 0; k < num_classes; ++k) for (long i = 0; i < num_classes; ++i)
{ {
const auto c_idx = tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c); const auto c_idx = tensor_index(output_tensor, n, k + 5 + i, r, c);
if (truth_box.label == options.labels[k]) if (truth_box.label == options.labels[i])
g[c_idx] = options.lambda_cls * (out_data[c_idx] - 1); g[c_idx] = options.lambda_cls * (out_data[c_idx] - 1);
else else
g[c_idx] = options.lambda_cls * out_data[c_idx]; g[c_idx] = options.lambda_cls * out_data[c_idx];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment