Unverified Commit 8df3f29a authored by Potter Hsu, committed by GitHub

Replace L1 loss with beta smooth L1 loss to achieve better performance (#2113)

* For the detection models, replace L1 loss with beta smooth L1 loss to achieve better performance.

* Add type annotations for torchscript

* Resolve E226
parent c031e287
@@ -346,3 +346,16 @@ class Matcher(object):
         pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
         matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
+
+
+def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True):
+    """
+    very similar to the smooth_l1_loss from pytorch, but with
+    the extra beta parameter
+    """
+    n = torch.abs(input - target)
+    cond = n < beta
+    loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
+    if size_average:
+        return loss.mean()
+    return loss.sum()
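For reference, this helper keeps the shape of PyTorch's smooth L1 loss but makes the quadratic-to-linear transition point configurable: the loss is 0.5 * x ** 2 / beta for |x| < beta and |x| - 0.5 * beta otherwise. A minimal sanity check of that behaviour, with made-up tensors that are not part of the patch:

import torch
import torch.nn.functional as F

def smooth_l1_loss(input, target, beta=1. / 9, size_average=True):
    # same formula as the helper added above
    n = torch.abs(input - target)
    loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    return loss.mean() if size_average else loss.sum()

x = torch.tensor([0.05, 0.5, 2.0])
y = torch.zeros(3)

# With beta=1 and size_average=False it reproduces F.smooth_l1_loss(..., reduction="sum").
assert torch.allclose(smooth_l1_loss(x, y, beta=1.0, size_average=False),
                      F.smooth_l1_loss(x, y, reduction="sum"))

# With the default beta=1/9 the quadratic zone shrinks, so residuals above 1/9 are
# penalized almost like plain L1 while the gradient still vanishes smoothly at zero.
print(smooth_l1_loss(x, y, size_average=False))  # ~2.40 for these values

Shrinking beta therefore pushes the loss towards plain L1 while keeping the gradient continuous at zero.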
@@ -43,10 +43,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
     N, num_classes = class_logits.shape
     box_regression = box_regression.reshape(N, -1, 4)
 
-    box_loss = F.smooth_l1_loss(
+    box_loss = det_utils.smooth_l1_loss(
         box_regression[sampled_pos_inds_subset, labels_pos],
         regression_targets[sampled_pos_inds_subset],
-        reduction="sum",
+        beta=1 / 9,
+        size_average=False,
     )
     box_loss = box_loss / labels.numel()
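This call site keeps the old normalization (a summed loss divided by labels.numel()); the functional change is only the smaller beta. Since the advanced indexing on the first argument is easy to misread, here is an illustrative sketch with made-up shapes showing what box_regression[sampled_pos_inds_subset, labels_pos] selects:

import torch

N, num_classes = 4, 3                               # toy sizes, not from the patch
box_regression = torch.randn(N, num_classes, 4)     # per-proposal, per-class box deltas
labels = torch.tensor([2, 0, 0, 1])                 # ground-truth class of each proposal
sampled_pos_inds_subset = torch.tensor([0, 3])      # which proposals are positives
labels_pos = labels[sampled_pos_inds_subset]        # their classes: tensor([2, 1])

# Advanced indexing with two index tensors selects, for each positive proposal,
# the 4 deltas predicted for its own class -> shape (2, 4).
selected = box_regression[sampled_pos_inds_subset, labels_pos]
print(selected.shape)  # torch.Size([2, 4])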
@@ -440,10 +440,11 @@ class RegionProposalNetwork(torch.nn.Module):
         labels = torch.cat(labels, dim=0)
         regression_targets = torch.cat(regression_targets, dim=0)
 
-        box_loss = F.l1_loss(
+        box_loss = det_utils.smooth_l1_loss(
             pred_bbox_deltas[sampled_pos_inds],
             regression_targets[sampled_pos_inds],
-            reduction="sum",
+            beta=1 / 9,
+            size_average=False,
         ) / (sampled_inds.numel())
 
         objectness_loss = F.binary_cross_entropy_with_logits(
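In the RPN the previous loss was plain L1, so this hunk changes behaviour only inside the narrow |x| < 1/9 band: residuals above beta keep L1's unit slope (each shifted down by the constant 0.5 * beta), while residuals below beta get a quadratic penalty whose gradient vanishes at zero. A rough numeric comparison with made-up residuals:

import torch
import torch.nn.functional as F

beta = 1. / 9
pred = torch.tensor([0.02, 0.3, 1.5])   # illustrative bbox-delta residuals
target = torch.zeros(3)

old = F.l1_loss(pred, target, reduction="sum")   # what the RPN used before this commit
n = torch.abs(pred - target)
new = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta).sum()
print(old, new)  # 1.82 vs ~1.69: only the 0.02 element falls in the quadratic zone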