Commit dbed904c authored by mibaumgartner

losses

parent 2cebb1a0
from nndet.losses.classification import focal_loss_with_logits, FocalLossWithLogits
from nndet.losses.regression import SmoothL1Loss, smooth_l1_loss, GIoULoss
from nndet.losses.segmentation import SoftDiceLoss
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
__all__ = ["reduction_helper"]
def reduction_helper(data: torch.Tensor, reduction: str) -> torch.Tensor:
"""
Helper to collapse data with different modes
Args:
data: data to collapse
reduction: type of reduction. One of `mean`, `sum`, `none` or None
Returns:
Tensor: reduced data
"""
if reduction == 'mean':
return torch.mean(data)
if reduction == 'none' or reduction is None:
return data
if reduction == 'sum':
return torch.sum(data)
raise AttributeError('Reduction parameter unknown.')
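# Illustrative usage sketch (not part of the original commit): shows how
# `reduction_helper` collapses an element-wise loss tensor under each of the
# supported reduction modes. Tensor values are made up for demonstration.
def _example_reduction_helper() -> None:
    per_element_loss = torch.tensor([1.0, 2.0, 3.0])
    assert torch.isclose(reduction_helper(per_element_loss, "mean"), torch.tensor(2.0))
    assert torch.isclose(reduction_helper(per_element_loss, "sum"), torch.tensor(6.0))
    # 'none' (or None) returns the unreduced tensor unchanged
    assert torch.equal(reduction_helper(per_element_loss, "none"), per_element_loss)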
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor
from loguru import logger
from nndet.losses.base import reduction_helper
from nndet.utils import make_onehot_batch
__all__ = ["focal_loss_with_logits", "FocalLossWithLogits"]
def one_hot_smooth(data,
num_classes: int,
smoothing: float = 0.0,
):
targets = torch.empty(size=(*data.shape, num_classes), device=data.device)\
.fill_(smoothing / num_classes)\
.scatter_(-1, data.long().unsqueeze(-1), 1. - smoothing)
return targets
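# Illustrative sketch (not part of the original commit): demonstrates the shape
# and values produced by `one_hot_smooth`. With smoothing=0.1 and 4 classes,
# off-target entries become 0.1 / 4 = 0.025 and the target entry 1 - 0.1 = 0.9.
def _example_one_hot_smooth() -> None:
    labels = torch.tensor([0, 2, 3])                                 # [N]
    encoded = one_hot_smooth(labels, num_classes=4, smoothing=0.1)   # [N, 4]
    assert encoded.shape == (3, 4)
    assert torch.isclose(encoded[1, 2], torch.tensor(0.9))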
@torch.jit.script
def focal_loss_with_logits(
logits: torch.Tensor,
target: torch.Tensor, gamma: float,
alpha: float = -1,
reduction: str = "mean",
) -> torch.Tensor:
"""
Focal loss
https://arxiv.org/abs/1708.02002
Args:
logits: predicted logits [N, dims]
target: (float) binary targets [N, dims]
gamma: balance easy and hard examples in focal loss
alpha: balance positive and negative samples [0, 1] (increasing
alpha increases the weight of the foreground classes (better recall))
reduction: 'mean'|'sum'|'none'
mean: mean of loss over entire batch
sum: sum of loss over entire batch
none: no reduction
Returns:
torch.Tensor: loss
See Also
:class:`BFocalLossWithLogits`, :class:`FocalLossWithLogits`
"""
bce_loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
p = torch.sigmoid(logits)
pt = (p * target + (1 - p) * (1 - target))
focal_term = (1. - pt).pow(gamma)
loss = focal_term * bce_loss
if alpha >= 0:
alpha_t = (alpha * target + (1 - alpha) * (1 - target))
loss = alpha_t * loss
return reduction_helper(loss, reduction=reduction)
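# Illustrative sketch (not part of the original commit): with gamma=0 the focal
# term is 1 and with alpha=-1 no class weighting is applied, so the scripted
# focal loss reduces to plain BCE-with-logits. Inputs are random toy tensors.
def _example_focal_loss_with_logits() -> None:
    logits = torch.randn(8, 3)
    targets = torch.randint(0, 2, (8, 3)).float()
    focal = focal_loss_with_logits(logits, targets, gamma=0.0, alpha=-1.0, reduction="mean")
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
    assert torch.isclose(focal, bce)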
class FocalLossWithLogits(nn.Module):
def __init__(self,
gamma: float = 2,
alpha: float = -1,
reduction: str = "sum",
loss_weight: float = 1.,
):
"""
Focal loss with multiple classes (uses one hot encoding and sigmoid)
Args:
gamma: balance easy and hard examples in focal loss
alpha: balance positive and negative samples [0, 1] (increasing
alpha increases the weight of the foreground classes (better recall))
reduction: 'mean'|'sum'|'none'
mean: mean of loss over entire batch
sum: sum of loss over entire batch
none: no reduction
loss_weight: scalar to balance multiple losses
"""
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
logits: torch.Tensor,
targets: torch.Tensor,
) -> torch.Tensor:
"""
Compute loss
Args:
logits: predicted logits [N, C, dims], where N is the batch size,
C is the number of classes and dims are arbitrary spatial dimensions
(the background class should be located at channel 0 if
ignore background is enabled)
targets: targets encoded as numbers [N, dims], where N is the
batch size, dims are arbitrary spatial dimensions
Returns:
torch.Tensor: loss
"""
n_classes = logits.shape[1] + 1
target_onehot = make_onehot_batch(targets, n_classes=n_classes).float()
target_onehot = target_onehot[:, 1:]
return self.loss_weight * focal_loss_with_logits(
logits, target_onehot,
gamma=self.gamma,
alpha=self.alpha,
reduction=self.reduction,
)
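# Illustrative usage sketch (not part of the original commit): shapes follow the
# docstring above, i.e. foreground-only logits [N, C, dims] and integer targets
# [N, dims] with 0 denoting background. Assumes `make_onehot_batch` produces an
# [N, C + 1, dims] encoding, as implied by `forward`.
def _example_focal_loss_module() -> None:
    criterion = FocalLossWithLogits(gamma=2.0, alpha=0.25, reduction="mean")
    logits = torch.randn(2, 3, 4, 4)             # batch 2, 3 foreground classes, 4x4 grid
    targets = torch.randint(0, 4, (2, 4, 4))     # 0 = background, 1..3 = foreground
    loss = criterion(logits, targets)
    assert loss.ndim == 0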
class BCEWithLogitsLossOneHot(torch.nn.BCEWithLogitsLoss):
def __init__(self,
*args,
num_classes: int,
smoothing: float = 0.0,
loss_weight: float = 1.,
**kwargs,
):
"""
BCE loss with one hot encoding of targets
Args:
num_classes: number of classes
smoothing: label smoothing
loss_weight: scalar to balance multiple losses
"""
super().__init__(*args, **kwargs)
self.smoothing = smoothing
if smoothing > 0:
logger.info(f"Running label smoothing with smoothing: {smoothing}")
self.num_classes = num_classes
self.loss_weight = loss_weight
def forward(self,
input: Tensor,
target: Tensor,
) -> Tensor:
"""
Compute bce loss based on one hot encoding
Args:
input: logits for all foreground classes [N, C]
N is the number of anchors, and C is the number of foreground
classes
target: target classes [N], where N is the number of anchors.
0 is treated as background, >0 are treated as foreground classes.
Returns:
Tensor: final loss
"""
target_one_hot = one_hot_smooth(
target, num_classes=self.num_classes + 1, smoothing=self.smoothing) # [N, C + 1]
target_one_hot = target_one_hot[:, 1:] # background is implicitly encoded
return self.loss_weight * super().forward(input, target_one_hot.float())
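# Illustrative usage sketch (not part of the original commit): per-anchor
# classification with 3 foreground classes. The target value 0 marks background
# and its one-hot channel is dropped inside `forward`.
def _example_bce_one_hot() -> None:
    criterion = BCEWithLogitsLossOneHot(num_classes=3, smoothing=0.0)
    logits = torch.randn(16, 3)                  # [N, C] foreground logits
    targets = torch.randint(0, 4, (16,))         # [N] with 0 = background
    loss = criterion(logits, targets)
    assert loss.ndim == 0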
class CrossEntropyLoss(torch.nn.CrossEntropyLoss):
def __init__(self,
*args,
loss_weight: float = 1.,
**kwargs,
) -> None:
"""
Same as CE from pytorch with additional loss weight for uniform API
"""
super().__init__(*args, **kwargs)
self.loss_weight = loss_weight
def forward(self,
input: Tensor,
target: Tensor,
) -> Tensor:
"""
Same as CE from pytorch
"""
return self.loss_weight * super().forward(input, target)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional
import torch
__all__ = ["SmoothL1Loss", "smooth_l1_loss"]
from nndet.detection.boxes.utils import generalized_box_iou
from nndet.losses.base import reduction_helper
class SmoothL1Loss(torch.nn.Module):
def __init__(self,
beta: float,
reduction: Optional[str] = None,
loss_weight: float = 1.,
):
"""
Module wrapper for functional
Args:
beta (float): L1 to L2 change point.
For beta values < 1e-5, L1 loss is computed.
reduction (str): 'none' | 'mean' | 'sum'
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
'sum': The output will be summed.
See Also:
:func:`smooth_l1_loss`
"""
super().__init__()
self.reduction = reduction
self.beta = beta
self.loss_weight = loss_weight
def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss
Args:
inp (torch.Tensor): predicted tensor (same shape as target)
target (torch.Tensor): target tensor
Returns:
Tensor: computed loss
"""
return self.loss_weight * reduction_helper(smooth_l1_loss(inp, target, self.beta), self.reduction)
def smooth_l1_loss(inp, target, beta: float):
"""
From https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/smooth_l1_loss.py
Smooth L1 loss defined in the Fast R-CNN paper as:
| 0.5 * x ** 2 / beta if abs(x) < beta
smoothl1(x) = |
| abs(x) - 0.5 * beta otherwise,
where x = input - target.
Smooth L1 loss is related to Huber loss, which is defined as:
| 0.5 * x ** 2 if abs(x) < beta
huber(x) = |
| beta * (abs(x) - 0.5 * beta) otherwise
Smooth L1 loss is equal to huber(x) / beta. This leads to the following
differences:
- As beta -> 0, Smooth L1 loss converges to L1 loss, while Huber loss
converges to a constant 0 loss.
- As beta -> +inf, Smooth L1 converges to a constant 0 loss, while Huber loss
converges to L2 loss.
- For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant
slope of 1. For Huber loss, the slope of the L1 segment is beta.
Smooth L1 loss can be seen as exactly L1 loss, but with the abs(x) < beta
portion replaced with a quadratic function such that at abs(x) = beta, its
slope is 1. The quadratic segment smooths the L1 loss near x = 0.
Args:
inp (Tensor): input tensor of any shape
target (Tensor): target value tensor with the same shape as input
beta (float): L1 to L2 change point.
For beta values < 1e-5, L1 loss is computed.
Returns:
Tensor: element-wise loss (no reduction is applied here; the reduction
is handled by :class:`SmoothL1Loss`)
Note:
PyTorch's builtin "Smooth L1 loss" implementation does not actually
implement Smooth L1 loss, nor does it implement Huber loss. It implements
the special case of both in which they are equal (beta=1).
See: https://pytorch.org/docs/stable/nn.html#torch.nn.SmoothL1Loss.
"""
if beta < 1e-5:
# if beta == 0, then torch.where will result in nan gradients when
# the chain rule is applied due to pytorch implementation details
# (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of
# zeros, rather than "no gradient"). To avoid this issue, we define
# small values of beta to be exactly l1 loss.
loss = torch.abs(inp - target)
else:
n = torch.abs(inp - target)
cond = n < beta
loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
return loss
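# Illustrative numerical check (not part of the original commit) of the piecewise
# definition documented above: with beta=1.0, |x|=0.5 falls into the quadratic
# branch (0.5 * 0.5**2 / 1.0 = 0.125) and |x|=2.0 into the linear branch
# (2.0 - 0.5 * 1.0 = 1.5).
def _example_smooth_l1_loss() -> None:
    inp = torch.tensor([0.5, 2.0])
    target = torch.tensor([0.0, 0.0])
    loss = smooth_l1_loss(inp, target, beta=1.0)
    assert torch.allclose(loss, torch.tensor([0.125, 1.5]))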
class GIoULoss(torch.nn.Module):
def __init__(self,
reduction: Optional[str] = None,
eps: float = 1e-7,
loss_weight: float = 1.,
):
"""
Generalized IoU Loss
`Generalized Intersection over Union: A Metric and A Loss for Bounding
Box Regression` https://arxiv.org/abs/1902.09630
Args:
eps: small constant for numerical stability
Notes:
Original paper uses lambda=10 to balance regression and cls losses
for PASCAL VOC and COCO (not tuned for coco)
`End-to-End Object Detection with Transformers` https://arxiv.org/abs/2005.12872
"Our enhanced Faster-RCNN+ baselines use GIoU [38] loss along with
the standard l1 loss for bounding box regression. We performed a grid search
to find the best weights for the losses and the final models use only GIoU loss
with weights 20 and 1 for box and proposal regression tasks respectively"
"""
super().__init__()
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self, pred_boxes: torch.Tensor, target_boxes: torch.Tensor) -> torch.Tensor:
"""
Compute generalized iou loss
Args:
pred_boxes: predicted boxes (x1, y1, x2, y2, (z1, z2)) [N, dim * 2]
target_boxes: target boxes (x1, y1, x2, y2, (z1, z2)) [N, dim * 2]
Returns:
Tensor: loss
"""
loss = reduction_helper(
torch.diag(generalized_box_iou(pred_boxes, target_boxes, eps=self.eps),
diagonal=0),
reduction=self.reduction)
return self.loss_weight * -1 * loss
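# Illustrative usage sketch (not part of the original commit), assuming
# `generalized_box_iou` returns the pairwise GIoU matrix (as implied by the
# `torch.diag` call in `forward`): identical predicted and target boxes give
# GIoU = 1 per box, so the negated, mean-reduced loss is approximately -1.
def _example_giou_loss() -> None:
    criterion = GIoULoss(reduction="mean")
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                          [5.0, 5.0, 20.0, 15.0]])
    loss = criterion(boxes, boxes.clone())
    assert torch.isclose(loss, torch.tensor(-1.0), atol=1e-5)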
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from loguru import logger
import torch
import torch.nn as nn
from torch import Tensor
from typing import Callable
__all__ = ["SoftDiceLoss"]
def one_hot_smooth_batch(data, num_classes: int, smoothing: float = 0.0):
shape = data.shape
targets = torch.empty(size=(shape[0], num_classes, *shape[1:]), device=data.device)\
.fill_(smoothing / num_classes)\
.scatter_(1, data.long().unsqueeze(1), 1. - smoothing)
return targets
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
"""
net_output must be (b, c, x, y(, z))
gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
if mask is provided it must have shape (b, 1, x, y(, z))
:param net_output:
:param gt:
:param axes:
:param mask: mask must be 1 for valid pixels and 0 for invalid pixels
:param square: if True then fp, tp and fn will be squared before summation
:return:
"""
if axes is None:
axes = tuple(range(2, len(net_output.size())))
shp_x = net_output.shape
shp_y = gt.shape
with torch.no_grad():
if len(shp_x) != len(shp_y):
gt = gt.view((shp_y[0], 1, *shp_y[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(shp_x)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = tp.sum(dim=axes, keepdim=False)
fp = fp.sum(dim=axes, keepdim=False)
fn = fn.sum(dim=axes, keepdim=False)
return tp, fp, fn
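# Illustrative sketch (not part of the original commit): with a perfect one-hot
# prediction, the soft true positives equal the per-class voxel counts and the
# false positives / negatives are zero. Shapes follow the docstring, i.e.
# (b, c, x, y) predictions and a (b, x, y) label map.
def _example_get_tp_fp_fn() -> None:
    gt = torch.tensor([[[0, 1], [1, 1]]])                      # (1, 2, 2) label map
    net_output = one_hot_smooth_batch(gt, num_classes=2)       # (1, 2, 2, 2) hard one-hot
    tp, fp, fn = get_tp_fp_fn(net_output, gt)
    assert torch.allclose(tp, torch.tensor([[1.0, 3.0]]))
    assert fp.sum() == 0 and fn.sum() == 0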
class SoftDiceLoss(nn.Module):
def __init__(self,
nonlin: Callable = None,
batch_dice: bool = False,
do_bg: bool = False,
smooth_nom: float = 1e-5,
smooth_denom: float = 1e-5,
):
"""
Soft dice loss
Args:
nonlin: nonlinearity applied to the predictions before the dice
computation (e.g. softmax). Defaults to None.
batch_dice: treat batch as pseudo volume. Defaults to False.
do_bg: include the background channel in the dice computation.
Defaults to False.
smooth_nom: smoothing term added to the numerator
smooth_denom: smoothing term added to the denominator
"""
super().__init__()
self.do_bg = do_bg
self.batch_dice = batch_dice
self.nonlin = nonlin
self.smooth_nom = smooth_nom
self.smooth_denom = smooth_denom
logger.info(f"Running batch dice {self.batch_dice} and "
f"do bg {self.do_bg} in dice loss.")
def forward(self,
inp: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor = None,
):
"""
Compute loss
Args:
inp (torch.Tensor): predictions
target (torch.Tensor): ground truth
loss_mask ([torch.Tensor], optional): binary mask. Defaults to None.
Returns:
torch.Tensor: soft dice loss
"""
shp_x = inp.shape
if self.batch_dice:
axes = [0] + list(range(2, len(shp_x)))
else:
axes = list(range(2, len(shp_x)))
if self.nonlin is not None:
inp = self.nonlin(inp)
tp, fp, fn = get_tp_fp_fn(inp, target, axes, loss_mask, False)
nominator = 2 * tp + self.smooth_nom
denominator = 2 * tp + fp + fn + self.smooth_denom
dc = nominator / denominator
if not self.do_bg:
if self.batch_dice:
dc = dc[1:]
else:
dc = dc[:, 1:]
dc = dc.mean()
return 1 - dc
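# Illustrative usage sketch (not part of the original commit): softmax is passed
# as the nonlinearity, predictions are raw logits (b, c, x, y) and the target is
# a label map (b, x, y). With random inputs the loss simply lands in [0, 1].
def _example_soft_dice_loss() -> None:
    criterion = SoftDiceLoss(nonlin=lambda x: torch.softmax(x, dim=1), batch_dice=True)
    logits = torch.randn(2, 3, 8, 8)             # 3 classes incl. background at channel 0
    target = torch.randint(0, 3, (2, 8, 8))
    loss = criterion(logits, target)
    assert 0.0 <= float(loss) <= 1.0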
class TopKLoss(torch.nn.CrossEntropyLoss):
def __init__(self,
topk: float,
loss_weight: float = 1.,
**kwargs,
):
"""
Uses the top-k fraction of values to compute the CE loss
(expects pre-softmax logits!)
Args:
topk: fraction of all entries (in [0, 1]) used for the loss computation
loss_weight: scalar to balance multiple losses
"""
if "reduction" in kwargs:
raise ValueError("Reduction is not supported in TopKLoss."
"This will always return the mean!")
super().__init__(
reduction="none",
**kwargs,
)
if topk < 0 or topk > 1:
raise ValueError("topk needs to be in the range [0, 1].")
self.topk = topk
self.loss_weight = loss_weight
def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""
Compute the CE loss and take the mean of the top-k fraction of the entries
Args:
input: logits for all foreground classes [N, C, *]
target: target classes. 0 is treated as background, >0 are
treated as foreground classes. [N, *]
Returns:
Tensor: final loss
"""
losses = super().forward(input, target)
k = int(losses.numel() * self.topk)
return self.loss_weight * losses.view(-1).topk(k=k, sorted=False)[0].mean()
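# Illustrative usage sketch (not part of the original commit): with topk=0.1 only
# the hardest 10% of the per-voxel CE values contribute to the mean, so the result
# is always at least as large as the plain mean CE over all entries.
def _example_topk_loss() -> None:
    criterion = TopKLoss(topk=0.1)
    logits = torch.randn(2, 4, 8, 8)             # (b, c, x, y) pre-softmax logits
    target = torch.randint(0, 4, (2, 8, 8))
    loss = criterion(logits, target)
    full_mean = torch.nn.functional.cross_entropy(logits, target)
    assert loss >= full_mean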
class TopKLossSigmoid(torch.nn.BCEWithLogitsLoss):
def __init__(self,
num_classes: int,
topk: float,
smoothing: float = 0.0,
loss_weight: float = 1.,
**kwargs,
):
"""
Uses the top-k fraction of values to compute the BCE loss with one-hot targets
(supports multiple classes through one-hot encoding, expects pre-sigmoid logits!)
Args:
num_classes: number of classes
topk: fraction of all entries (in [0, 1]) used for the loss computation
smoothing: label smoothing
loss_weight: scalar to balance multiple losses
"""
if "reduction" in kwargs:
raise ValueError("Reduction is not supported in TopKLoss."
"This will always return the mean!")
super().__init__(
reduction="none",
**kwargs,
)
self.smoothing = smoothing
if smoothing > 0:
logger.info(f"Running label smoothing with smoothing: {smoothing}")
self.num_classes = num_classes
self.topk = topk
self.loss_weight = loss_weight
def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""
Compute the BCE loss based on a one-hot encoding of the foreground(!) classes
and take the mean of the top-k fraction of the entries
Args:
input: logits for all foreground(!) classes [N, C, *]
target: target classes [N, *]. Targets will be encoded with one
hot and 0 is treated as the background class and removed.
Returns:
Tensor: final loss
"""
target_one_hot = one_hot_smooth_batch(
target, num_classes=self.num_classes + 1, smoothing=self.smoothing) # [N, C + 1]
target_one_hot = target_one_hot[:, 1:] # background is implicitly encoded
losses = super().forward(input, target_one_hot.float())
k = int(losses.numel() * self.topk)
return self.loss_weight * losses.view(-1).topk(k=k, sorted=False)[0].mean()
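# Illustrative usage sketch (not part of the original commit): foreground logits
# have shape (b, c, x, y) with c = num_classes foreground channels, while the
# target label map uses 0 for background. Only the hardest 25% of the
# element-wise BCE values enter the mean.
def _example_topk_loss_sigmoid() -> None:
    criterion = TopKLossSigmoid(num_classes=3, topk=0.25)
    logits = torch.randn(2, 3, 8, 8)
    target = torch.randint(0, 4, (2, 8, 8))      # 0 = background, 1..3 = foreground
    loss = criterion(logits, target)
    assert loss.ndim == 0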