# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import convert_to_one_hot, weight_reduce_loss


def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""Sigmoid focal loss.

    Args:
        pred (torch.Tensor): Raw logits with shape (N, \*).
        target (torch.Tensor): Binary ground-truth labels with the same
            shape (N, \*) as ``pred``.
        weight (torch.Tensor, optional): Sample-wise loss weight with shape
            (N, ). Defaults to None.
        gamma (float): Focusing parameter of the modulating factor.
            Defaults to 2.0.
        alpha (float): Balancing factor between positive and negative
            examples. Defaults to 0.25.
        reduction (str): How to reduce the loss. Options are "none", "mean"
            and "sum". With "none" the returned loss keeps the shape of
            ``pred``. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: Loss.
    """
    assert pred.shape == \
        target.shape, 'pred and target should be in the same shape.'
    target = target.type_as(pred)
    prob = pred.sigmoid()
    # Probability mass assigned to the *wrong* class: large values mean a
    # hard example, which the modulating factor up-weights.
    pt = target * (1 - prob) + (1 - target) * prob
    modulator = pt.pow(gamma) * (target * alpha + (1 - target) * (1 - alpha))
    loss = modulator * F.binary_cross_entropy_with_logits(
        pred, target, reduction='none')
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            # Broadcast the per-sample weight across the class dimension.
            weight = weight.reshape(-1, 1)
    return weight_reduce_loss(loss, weight, reduction, avg_factor)


@LOSSES.register_module()
class FocalLoss(nn.Module):
    """Focal loss.

    Args:
        gamma (float): Focusing parameter in focal loss.
            Defaults to 2.0.
        alpha (float): The parameter in balanced form of focal
            loss. Defaults to 0.25.
        reduction (str): The method used to reduce the loss into
            a scalar. Options are "none" and "mean". Defaults to 'mean'.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):

        super(FocalLoss, self).__init__()
        # Stored as-is; forwarded to ``sigmoid_focal_loss`` on every call.
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        r"""Compute the sigmoid focal loss.

        Args:
            pred (torch.Tensor): The prediction with shape (N, \*).
            target (torch.Tensor): The ground truth label of the prediction
                with shape (N, \*), N or (N,1).
            weight (torch.Tensor, optional): Sample-wise loss weight with shape
                (N, \*). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss into a scalar. Options are "none", "mean" and "sum".
                Defaults to None.

        Returns:
            torch.Tensor: Loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A non-None override takes precedence over the configured reduction;
        # valid override strings are all truthy, so ``or`` is equivalent.
        reduction = reduction_override or self.reduction
        # Hard class indices of shape (N,) or (N, 1) are expanded into a
        # (N, num_classes) one-hot target before the element-wise loss.
        is_index_target = target.dim() == 1 or (
            target.dim() == 2 and target.shape[1] == 1)
        if is_index_target:
            target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1])
        return self.loss_weight * sigmoid_focal_loss(
            pred,
            target,
            weight,
            gamma=self.gamma,
            alpha=self.alpha,
            reduction=reduction,
            avg_factor=avg_factor)