# focal_loss.py
import torch
import torch.nn.functional as F

from ..utils import _log_api_usage_once


def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example, as raw logits
                (a sigmoid is applied internally).
        targets: A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples, or any negative value
                (e.g. -1) to disable the weighting. Default = 0.25
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples. Default = 2
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
                 Default = 'none'
    Returns:
        Loss tensor with the reduction option applied: the element-wise loss
        (same shape as ``inputs``) for 'none', a 0-dim tensor for 'mean' / 'sum'.
    Raises:
        ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
    """
    _log_api_usage_once("torchvision.ops.sigmoid_focal_loss")
    p = torch.sigmoid(inputs)
    # Cross-entropy is computed directly from the logits rather than from `p`;
    # PyTorch documents the *_with_logits form as the numerically stable one.
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability of the ground-truth class for each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    # The modulating factor (1 - p_t) ** gamma down-weights well-classified
    # (high p_t) elements so that training focuses on hard examples.
    loss = ce_loss * ((1 - p_t) ** gamma)

    # A negative alpha disables class-balance weighting entirely.
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "none":
        pass
    elif reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    else:
        # Previously an unknown mode silently returned the unreduced loss,
        # hiding typos such as "avg" from the caller.
        raise ValueError(
            f"Invalid value for arg 'reduction': '{reduction}'. "
            "Supported reduction modes: 'none', 'mean', 'sum'"
        )

    return loss