focal_loss.py 2.21 KB
Newer Older
1
2
3
import torch
import torch.nn.functional as F

4
5
from ..utils import _log_api_usage_once

6
7

def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    The per-element loss is ``alpha_t * (1 - p_t) ** gamma * ce``, where ``ce`` is
    the binary cross entropy computed from the logits, ``p = sigmoid(inputs)``,
    ``p_t = p * targets + (1 - p) * (1 - targets)`` and
    ``alpha_t = alpha * targets + (1 - alpha) * (1 - targets)``.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape.
                The predictions for each example, given as logits (the sigmoid
                is applied inside this function).
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range (0,1) to balance
                positive vs negative examples. Any negative value (e.g. ``-1``)
                disables the alpha weighting. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged.
                ``'sum'``: The output will be summed. Default: ``'none'``.

    Returns:
        Loss tensor with the reduction option applied.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.
    """
    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

    # Usage logging is skipped while the function is being scripted or traced.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(sigmoid_focal_loss)

    p = torch.sigmoid(inputs)
    # Cross entropy is computed from the raw logits rather than from `p`.
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # p_t is the probability assigned to the labelled class of each element:
    # p where targets == 1 and (1 - p) where targets == 0.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor: scales the loss down as p_t approaches 1.
    loss = ce_loss * ((1 - p_t) ** gamma)

    # A negative alpha turns the class-balancing weight off.
    if alpha >= 0:
        # alpha for positive targets, (1 - alpha) for negative targets.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    # Check reduction option and return loss accordingly
    if reduction == "none":
        pass
    elif reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    else:
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
        )
    return loss