Unverified Commit ba26d8ce authored by Kai Chen, committed by GitHub

Refactoring for losses (#761)

* refactoring for losses

* update configs for guided anchoring

* add all imported losses to __all__

* allow weight=None for binary_cross_entropy

* use losses in mmdetection for FCOSHead

* bug fix for weight_reduce_loss

* add eps to iou_loss and handle weight=None

* unify loss api in FCOSHead

* fix avg_factor
parent afb7ec86
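The refactor gives every loss the same entry points: a functional form `loss(pred, target, weight=None, reduction='mean', avg_factor=None)` and an `nn.Module` wrapper registered in `LOSSES`, so heads and configs can swap losses declaratively. A minimal, illustrative config sketch (keys and values below are examples, not taken from this diff):

# Hypothetical head config: loss modules are built from the LOSSES registry by type name.
loss_cls = dict(
    type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0)
loss_bbox = dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)
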
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import weight_reduce_loss, weighted_loss
from ..registry import LOSSES

cross_entropy = weighted_loss(F.cross_entropy)


def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    # element-wise losses
    if weight is not None:
        weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight, reduction='none')
    # apply weights and do the reduction
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None):
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, reduction='mean')[None]


@LOSSES.register_module
class CrossEntropyLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def forward(self, cls_score, label, weight=None, avg_factor=None,
                **kwargs):
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            reduction=self.reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
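
A quick illustration of the unified module API (the tensors below are made up; only the call pattern comes from the class above):

import torch

# 8 samples, 4 foreground classes; label 0 means background.
criterion = CrossEntropyLoss(use_sigmoid=True, reduction='mean')
cls_score = torch.randn(8, 4)
label = torch.randint(0, 5, (8, ))
label_weight = torch.ones(8)
# avg_factor replaces the default mean denominator (e.g. the number of samples).
loss = criterion(cls_score, label, label_weight, avg_factor=8.0)
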
import torch.nn as nn
import torch.nn.functional as F

from mmdet.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from .utils import weight_reduce_loss
from ..registry import LOSSES


# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    loss = _sigmoid_focal_loss(pred, target, gamma, alpha)
    # TODO: find a proper way to handle the shape of weight
    if weight is not None:
        weight = weight.view(-1, 1)
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


@LOSSES.register_module
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, avg_factor=None):
        if self.use_sigmoid:
            loss_cls = self.loss_weight * sigmoid_focal_loss(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=self.reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls
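
The registered FocalLoss module always goes through the CUDA op, while py_sigmoid_focal_loss is kept for debugging; a CPU-only sanity check could look like this (tensors are illustrative, not part of the diff):

import torch

# 4 anchors, 3 classes, one-hot binary targets as the sigmoid formulation expects.
pred = torch.randn(4, 3)
target = torch.tensor([[1., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 1.]])
loss = py_sigmoid_focal_loss(
    pred, target, gamma=2.0, alpha=0.25, reduction='mean')
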
@@ -29,12 +29,8 @@ class GHMC(nn.Module):
         use_sigmoid (bool): Can only be true for BCE based loss now.
         loss_weight (float): The weight of the total GHM-C loss.
     """
-    def __init__(
-            self,
-            bins=10,
-            momentum=0,
-            use_sigmoid=True,
-            loss_weight=1.0):
+
+    def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0):
         super(GHMC, self).__init__()
         self.bins = bins
         self.momentum = momentum
@@ -76,7 +72,7 @@ class GHMC(nn.Module):
         tot = max(valid.float().sum().item(), 1.0)
         n = 0  # n valid bins
         for i in range(self.bins):
-            inds = (g >= edges[i]) & (g < edges[i+1]) & valid
+            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
             num_in_bin = inds.sum().item()
             if num_in_bin > 0:
                 if mmt > 0:
@@ -108,12 +104,8 @@ class GHMR(nn.Module):
         momentum (float): The parameter for moving average.
         loss_weight (float): The weight of the total GHM-R loss.
     """
-    def __init__(
-            self,
-            mu=0.02,
-            bins=10,
-            momentum=0,
-            loss_weight=1.0):
+
+    def __init__(self, mu=0.02, bins=10, momentum=0, loss_weight=1.0):
         super(GHMR, self).__init__()
         self.mu = mu
         self.bins = bins
@@ -154,7 +146,7 @@ class GHMR(nn.Module):
         tot = max(label_weight.float().sum().item(), 1.0)
         n = 0  # n: valid bins
         for i in range(self.bins):
-            inds = (g >= edges[i]) & (g < edges[i+1]) & valid
+            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
             num_in_bin = inds.sum().item()
             if num_in_bin > 0:
                 n += 1
import torch
import torch.nn as nn

from mmdet.core import bbox_overlaps
from .utils import weighted_loss
from ..registry import LOSSES


@weighted_loss
def iou_loss(pred, target, eps=1e-6):
    """IoU loss.

    Computing the IoU loss between a set of predicted bboxes and target
    bboxes. The loss is calculated as negative log of IoU.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Eps to avoid log(0).

    Return:
        Tensor: Loss tensor.
    """
    ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps)
    loss = -ious.log()
    return loss


@weighted_loss
def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
    """Improving Object Localization with Fitness NMS and Bounded IoU Loss,
    https://arxiv.org/abs/1711.00164.

    Args:
        pred (tensor): Predicted bboxes.
        target (tensor): Target bboxes.
        beta (float): beta parameter in smoothl1.
        eps (float): eps to avoid NaN.
    """
    pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
    pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
    pred_w = pred[:, 2] - pred[:, 0] + 1
    pred_h = pred[:, 3] - pred[:, 1] + 1
    with torch.no_grad():
        target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
        target_ctry = (target[:, 1] + target[:, 3]) * 0.5
        target_w = target[:, 2] - target[:, 0] + 1
        target_h = target[:, 3] - target[:, 1] + 1

    dx = target_ctrx - pred_ctrx
    dy = target_ctry - pred_ctry

    loss_dx = 1 - torch.max(
        (target_w - 2 * dx.abs()) /
        (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx))
    loss_dy = 1 - torch.max(
        (target_h - 2 * dy.abs()) /
        (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy))
    loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w /
                            (target_w + eps))
    loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h /
                            (target_h + eps))
    loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh],
                            dim=-1).view(loss_dx.size(0), -1)

    loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta,
                       loss_comb - 0.5 * beta)
    return loss


@LOSSES.register_module
class IoULoss(nn.Module):

    def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
        super(IoULoss, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
        if weight is not None and not torch.any(weight > 0):
            return (pred * weight).sum()  # 0
        loss = self.loss_weight * iou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=self.reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@LOSSES.register_module
class BoundedIoULoss(nn.Module):

    def __init__(self, beta=0.2, eps=1e-3, reduction='mean', loss_weight=1.0):
        super(BoundedIoULoss, self).__init__()
        self.beta = beta
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
        if weight is not None and not torch.any(weight > 0):
            return (pred * weight).sum()  # 0
        loss = self.loss_weight * bounded_iou_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            eps=self.eps,
            reduction=self.reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss
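
Since both functions are wrapped by @weighted_loss, they accept weight, reduction and avg_factor directly; a small sketch with hand-made boxes (illustrative values only):

import torch

pred = torch.tensor([[0., 0., 10., 10.],
                     [5., 5., 15., 15.]])
target = torch.tensor([[0., 0., 10., 10.],
                       [6., 6., 16., 16.]])
loss = iou_loss(pred, target, reduction='mean')     # scalar, mean of -log(IoU)
per_box = iou_loss(pred, target, reduction='none')  # first entry is 0 (identical boxes)
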
import torch
import torch.nn as nn

from .utils import weighted_loss
from ..registry import LOSSES


@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
    assert beta > 0
    assert pred.size() == target.size() and target.numel() > 0
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    return loss


@LOSSES.register_module
class SmoothL1Loss(nn.Module):

    def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
        super(SmoothL1Loss, self).__init__()
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
        loss_bbox = self.loss_weight * smooth_l1_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            reduction=self.reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_bbox
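
Worked example for the element-wise branch above (hand-picked values): with the default beta=1.0, |diff| < beta gives 0.5 * diff**2 / beta, otherwise |diff| - 0.5 * beta.

import torch

pred = torch.tensor([0.5, 3.0])
target = torch.tensor([0.0, 0.0])
# |0.5| < 1.0 -> 0.5 * 0.5**2 = 0.125 ; |3.0| >= 1.0 -> 3.0 - 0.5 = 2.5
smooth_l1_loss(pred, target, reduction='none')  # tensor([0.1250, 2.5000])
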
import functools

import torch.nn.functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    # otherwise average the loss by avg_factor
    else:
        if reduction != 'mean':
            raise ValueError(
                'avg_factor can only be used with reduction="mean"')
        loss = loss.sum() / avg_factor
    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`.

    :Example:

    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper
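
Any new element-wise loss can reuse this machinery by defining only the unreduced term; the mse_loss below is a made-up illustration, not part of this commit:

import torch

@weighted_loss
def mse_loss(pred, target):
    # element-wise term only; weight/reduction/avg_factor come from the decorator
    return (pred - target).pow(2)

pred = torch.tensor([0., 2., 3.])
target = torch.tensor([1., 1., 1.])
weight = torch.tensor([1., 0., 1.])
mse_loss(pred, target, weight, avg_factor=2)  # (1 + 0 + 4) / 2 = 2.5
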
-import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable

@@ -8,7 +7,7 @@ from .. import sigmoid_focal_loss_cuda
 class SigmoidFocalLossFunction(Function):

     @staticmethod
-    def forward(ctx, input, target, gamma=2.0, alpha=0.25, reduction='mean'):
+    def forward(ctx, input, target, gamma=2.0, alpha=0.25):
         ctx.save_for_backward(input, target)
         num_classes = input.shape[1]
         ctx.num_classes = num_classes

@@ -17,14 +16,7 @@ class SigmoidFocalLossFunction(Function):
         loss = sigmoid_focal_loss_cuda.forward(input, target, num_classes,
                                                gamma, alpha)
-        reduction_enum = F._Reduction.get_enum(reduction)
-        # none: 0, mean:1, sum: 2
-        if reduction_enum == 0:
-            return loss
-        elif reduction_enum == 1:
-            return loss.mean()
-        elif reduction_enum == 2:
-            return loss.sum()
+        return loss

     @staticmethod
     @once_differentiable
@@ -3,6 +3,7 @@ from torch import nn
 from ..functions.sigmoid_focal_loss import sigmoid_focal_loss


+# TODO: remove this module
 class SigmoidFocalLoss(nn.Module):

     def __init__(self, gamma, alpha):