Unverified Commit 797fd26f authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use pytorch smooth_l1_loss and remove private custom implem (#3539)

parent 551193b4
...@@ -344,19 +344,6 @@ class Matcher(object): ...@@ -344,19 +344,6 @@ class Matcher(object):
matches[pred_inds_to_update] = all_matches[pred_inds_to_update] matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True):
"""
very similar to the smooth_l1_loss from pytorch, but with
the extra beta parameter
"""
n = torch.abs(input - target)
cond = n < beta
loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
if size_average:
return loss.mean()
return loss.sum()
def overwrite_eps(model, eps): def overwrite_eps(model, eps):
""" """
This method overwrites the default eps values of all the This method overwrites the default eps values of all the
......
...@@ -42,11 +42,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): ...@@ -42,11 +42,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
N, num_classes = class_logits.shape N, num_classes = class_logits.shape
box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4) box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
box_loss = det_utils.smooth_l1_loss( box_loss = F.smooth_l1_loss(
box_regression[sampled_pos_inds_subset, labels_pos], box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset], regression_targets[sampled_pos_inds_subset],
beta=1 / 9, beta=1 / 9,
size_average=False, reduction='sum',
) )
box_loss = box_loss / labels.numel() box_loss = box_loss / labels.numel()
......
...@@ -304,11 +304,11 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -304,11 +304,11 @@ class RegionProposalNetwork(torch.nn.Module):
labels = torch.cat(labels, dim=0) labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0) regression_targets = torch.cat(regression_targets, dim=0)
box_loss = det_utils.smooth_l1_loss( box_loss = F.smooth_l1_loss(
pred_bbox_deltas[sampled_pos_inds], pred_bbox_deltas[sampled_pos_inds],
regression_targets[sampled_pos_inds], regression_targets[sampled_pos_inds],
beta=1 / 9, beta=1 / 9,
size_average=False, reduction='sum',
) / (sampled_inds.numel()) ) / (sampled_inds.numel())
objectness_loss = F.binary_cross_entropy_with_logits( objectness_loss = F.binary_cross_entropy_with_logits(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment