"tests/scripts/task_mxnet_tutorial_test.sh" did not exist on "7cbb83db9f10804666c28aacd6c0c61a3dcd4e48"
Unverified Commit a2c47603 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Use lambda obj in the ignore case and do some refactoring (#2466)

parent cf21f5aa
...@@ -3641,7 +3641,7 @@ namespace dlib ...@@ -3641,7 +3641,7 @@ namespace dlib
h / (1 - h) * anchors[a].height)); h / (1 - h) * anchors[a].height));
for (long k = 0; k < num_classes; ++k) for (long k = 0; k < num_classes; ++k)
{ {
const float conf = out_data[tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c)] * obj; const float conf = obj * out_data[tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c)];
if (conf > adjust_threshold) if (conf > adjust_threshold)
det.labels.emplace_back(conf, options.labels[k]); det.labels.emplace_back(conf, options.labels[k]);
} }
...@@ -3716,7 +3716,7 @@ namespace dlib ...@@ -3716,7 +3716,7 @@ namespace dlib
// Incur loss for the boxes that are below a certain IoU threshold with any truth box // Incur loss for the boxes that are below a certain IoU threshold with any truth box
if (best_iou < options.iou_ignore_threshold) if (best_iou < options.iou_ignore_threshold)
g[o_idx] = out_data[o_idx]; g[o_idx] = options.lambda_obj * out_data[o_idx];
} }
} }
} }
...@@ -3780,13 +3780,13 @@ namespace dlib ...@@ -3780,13 +3780,13 @@ namespace dlib
const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height()); const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height());
// Scale regression error according to the truth size // Scale regression error according to the truth size
const double scale_box = 2 - truth_box.rect.area() / input_area; const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_area);
// Compute the gradient for the box coordinates // Compute the gradient for the box coordinates
g[x_idx] = options.lambda_box * scale_box * (out_data[x_idx] - tx); g[x_idx] = scale_box * (out_data[x_idx] - tx);
g[y_idx] = options.lambda_box * scale_box * (out_data[y_idx] - ty); g[y_idx] = scale_box * (out_data[y_idx] - ty);
g[w_idx] = options.lambda_box * scale_box * (out_data[w_idx] - tw); g[w_idx] = scale_box * (out_data[w_idx] - tw);
g[h_idx] = options.lambda_box * scale_box * (out_data[h_idx] - th); g[h_idx] = scale_box * (out_data[h_idx] - th);
// Compute the classification error // Compute the classification error
for (long k = 0; k < num_classes; ++k) for (long k = 0; k < num_classes; ++k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment