Commit 32df98e9 authored by Cao Yuhang, committed by Kai Chen

Add reduction_override flag (#839)

* add reduction_override flag

* change default value of reduction_override as None

* add assertion, fix format

* delete redundant statement in util

* delete redundant comment
parent fc0172b4
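
The change applies one idiom uniformly: each loss module keeps the `reduction` chosen at construction time, and callers may override it for a single call via `reduction_override`. A minimal sketch of the idiom, assuming a toy `L1LossExample` module (the class is illustrative, not part of this commit; only the assert-and-override lines mirror the diff):

import torch
import torch.nn as nn
import torch.nn.functional as F


class L1LossExample(nn.Module):
    # hypothetical module for illustration, not part of this commit
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target, reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # a per-call override takes precedence over the configured default
        reduction = (
            reduction_override if reduction_override else self.reduction)
        return F.l1_loss(pred, target, reduction=reduction)


loss = L1LossExample(reduction='mean')
pred, target = torch.rand(4, 2), torch.rand(4, 2)
loss(pred, target)                              # scalar (configured 'mean')
loss(pred, target, reduction_override='none')   # per-element, shape (4, 2)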
@@ -36,7 +36,7 @@ class OHEMSampler(BaseSampler):
                 label_weights=cls_score.new_ones(cls_score.size(0)),
                 bbox_targets=None,
                 bbox_weights=None,
-                reduce=False)['loss_cls']
+                reduction_override='none')['loss_cls']
             _, topk_loss_inds = loss.topk(num_expected)
         return inds[topk_loss_inds]
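
Context for the sampler change above: online hard example mining needs one classification loss per candidate box so it can keep the `num_expected` hardest ones via `topk`; the old boolean `reduce=False` is subsumed by the more general `reduction_override='none'`. A toy illustration (numbers invented):

import torch

per_sample_loss = torch.tensor([0.1, 2.3, 0.7, 1.5])  # one loss per candidate
_, topk_inds = per_sample_loss.topk(2)                 # the 2 hardest examples
# topk_inds -> tensor([1, 3])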
@@ -97,12 +97,16 @@ class BBoxHead(nn.Module):
              label_weights,
              bbox_targets,
              bbox_weights,
-             reduce=True):
+             reduction_override=None):
         losses = dict()
         if cls_score is not None:
             avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
             losses['loss_cls'] = self.loss_cls(
-                cls_score, labels, label_weights, avg_factor=avg_factor)
+                cls_score,
+                labels,
+                label_weights,
+                avg_factor=avg_factor,
+                reduction_override=reduction_override)
             losses['acc'] = accuracy(cls_score, labels)
         if bbox_pred is not None:
             pos_inds = labels > 0
@@ -115,7 +119,8 @@ class BBoxHead(nn.Module):
                 pos_bbox_pred,
                 bbox_targets[pos_inds],
                 bbox_weights[pos_inds],
-                avg_factor=bbox_targets.size(0))
+                avg_factor=bbox_targets.size(0),
+                reduction_override=reduction_override)
         return losses

     def get_det_bboxes(self,
@@ -46,7 +46,16 @@ class BalancedL1Loss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_bbox = self.loss_weight * balanced_l1_loss(
             pred,
             target,
@@ -54,7 +63,7 @@ class BalancedL1Loss(nn.Module):
             alpha=self.alpha,
             gamma=self.gamma,
             beta=self.beta,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
@@ -73,13 +73,21 @@ class CrossEntropyLoss(nn.Module):
         else:
             self.cls_criterion = cross_entropy

-    def forward(self, cls_score, label, weight=None, avg_factor=None,
-                **kwargs):
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
             weight,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_cls
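
With the override in place, a loss configured for mean reduction can still hand back per-sample values on demand. A usage sketch, assuming the module is importable as `mmdet.models.losses.CrossEntropyLoss` (import path assumed):

import torch
from mmdet.models.losses import CrossEntropyLoss  # path assumed

loss_cls = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
scores = torch.randn(8, 4)            # 8 samples, 4 classes
labels = torch.randint(0, 4, (8,))
loss_cls(scores, labels)                             # scalar mean loss
loss_cls(scores, labels, reduction_override='none')  # per-sample, shape (8,)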
@@ -59,7 +59,15 @@ class FocalLoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         if self.use_sigmoid:
             loss_cls = self.loss_weight * sigmoid_focal_loss(
                 pred,
@@ -67,7 +75,7 @@ class FocalLoss(nn.Module):
                 weight,
                 gamma=self.gamma,
                 alpha=self.alpha,
-                reduction=self.reduction,
+                reduction=reduction,
                 avg_factor=avg_factor)
         else:
             raise NotImplementedError
@@ -15,6 +15,7 @@ def _expand_binary_labels(labels, label_weights, label_channels):
     return bin_labels, bin_label_weights

+# TODO: code refactoring to make it consistent with other losses
 @LOSSES.register_module
 class GHMC(nn.Module):
     """GHM Classification Loss.
@@ -90,6 +91,7 @@ class GHMC(nn.Module):
         return loss * self.loss_weight

+# TODO: code refactoring to make it consistent with other losses
 @LOSSES.register_module
 class GHMR(nn.Module):
     """GHM Regression Loss.
@@ -116,6 +118,7 @@ class GHMR(nn.Module):
             self.acc_sum = torch.zeros(bins).cuda()
         self.loss_weight = loss_weight

+    # TODO: support reduction parameter
     def forward(self, pred, target, label_weight, avg_factor=None):
         """Calculate the GHM-R loss.
@@ -78,15 +78,24 @@ class IoULoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
         if weight is not None and not torch.any(weight > 0):
             return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss = self.loss_weight * iou_loss(
             pred,
             target,
             weight,
             eps=self.eps,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss
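
One subtlety in `IoULoss.forward` above: the all-zero-weight early return fires before the new assert, so in that case the module returns a zero scalar regardless of any `reduction_override`. A toy illustration of why `(pred * weight).sum()` is used rather than a bare `0`:

import torch

pred = torch.rand(3, 4, requires_grad=True)
weight = torch.zeros(3, 4)
zero = (pred * weight).sum()   # value 0.0, but still in the autograd graph
zero.backward()                # backward succeeds; gradients are all zero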
@@ -24,13 +24,22 @@ class SmoothL1Loss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_bbox = self.loss_weight * smooth_l1_loss(
             pred,
             target,
             weight,
             beta=self.beta,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
@@ -42,12 +42,13 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     # if avg_factor is not specified, just reduce the loss
     if avg_factor is None:
         loss = reduce_loss(loss, reduction)
-    # otherwise average the loss by avg_factor
     else:
-        if reduction != 'mean':
-            raise ValueError(
-                'avg_factor can only be used with reduction="mean"')
-        loss = loss.sum() / avg_factor
+        # if reduction is mean, then average the loss by avg_factor
+        if reduction == 'mean':
+            loss = loss.sum() / avg_factor
+        # if reduction is 'none', then do nothing, otherwise raise an error
+        elif reduction != 'none':
+            raise ValueError('avg_factor can not be used with reduction="sum"')
     return loss
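
The reworked branch above changes the `weight_reduce_loss` semantics: `avg_factor` with `reduction='mean'` divides the summed loss by `avg_factor` (not by element count), `avg_factor` with `reduction='none'` is now a pass-through instead of an error, and only the `'sum'` combination is rejected. A self-contained restatement of that branch logic (toy function, not the mmdet code itself):

import torch


def reduce_by_avg_factor(loss, reduction='mean', avg_factor=None):
    # mirrors the updated else-branch of weight_reduce_loss
    if avg_factor is None:
        return {'none': loss, 'mean': loss.mean(), 'sum': loss.sum()}[reduction]
    if reduction == 'mean':
        return loss.sum() / avg_factor   # averaged by avg_factor, not numel
    if reduction == 'none':
        return loss                      # new: no longer raises
    raise ValueError('avg_factor can not be used with reduction="sum"')


loss = torch.tensor([1.0, 2.0, 3.0])
reduce_by_avg_factor(loss, 'mean', avg_factor=10)   # tensor(0.6000)
reduce_by_avg_factor(loss, 'none', avg_factor=10)   # tensor([1., 2., 3.])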