Unverified Commit a2c47603 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Use lambda obj in the ignore case and do some refactoring (#2466)

parent cf21f5aa
...@@ -3641,7 +3641,7 @@ namespace dlib ...@@ -3641,7 +3641,7 @@ namespace dlib
h / (1 - h) * anchors[a].height)); h / (1 - h) * anchors[a].height));
for (long k = 0; k < num_classes; ++k) for (long k = 0; k < num_classes; ++k)
{ {
const float conf = out_data[tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c)] * obj; const float conf = obj * out_data[tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c)];
if (conf > adjust_threshold) if (conf > adjust_threshold)
det.labels.emplace_back(conf, options.labels[k]); det.labels.emplace_back(conf, options.labels[k]);
} }
...@@ -3716,7 +3716,7 @@ namespace dlib ...@@ -3716,7 +3716,7 @@ namespace dlib
// Incur loss for the boxes that are below a certain IoU threshold with any truth box // Incur loss for the boxes that are below a certain IoU threshold with any truth box
if (best_iou < options.iou_ignore_threshold) if (best_iou < options.iou_ignore_threshold)
g[o_idx] = out_data[o_idx]; g[o_idx] = options.lambda_obj * out_data[o_idx];
} }
} }
} }
...@@ -3780,13 +3780,13 @@ namespace dlib ...@@ -3780,13 +3780,13 @@ namespace dlib
const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height()); const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height());
// Scale regression error according to the truth size // Scale regression error according to the truth size
const double scale_box = 2 - truth_box.rect.area() / input_area; const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_area);
// Compute the gradient for the box coordinates // Compute the gradient for the box coordinates
g[x_idx] = options.lambda_box * scale_box * (out_data[x_idx] - tx); g[x_idx] = scale_box * (out_data[x_idx] - tx);
g[y_idx] = options.lambda_box * scale_box * (out_data[y_idx] - ty); g[y_idx] = scale_box * (out_data[y_idx] - ty);
g[w_idx] = options.lambda_box * scale_box * (out_data[w_idx] - tw); g[w_idx] = scale_box * (out_data[w_idx] - tw);
g[h_idx] = options.lambda_box * scale_box * (out_data[h_idx] - th); g[h_idx] = scale_box * (out_data[h_idx] - th);
// Compute the classification error // Compute the classification error
for (long k = 0; k < num_classes; ++k) for (long k = 0; k < num_classes; ++k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment