focal_loss.py 2.21 KB
Newer Older
1
2
3
import torch
import torch.nn.functional as F

limm's avatar
limm committed
4
5
from ..utils import _log_api_usage_once

6
7

def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape.
                The predictions for each example.
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range (0,1) to balance
                positive vs negative examples or -1 for ignore. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged.
                ``'sum'``: The output will be summed. Default: ``'none'``.

    Returns:
        Loss tensor with the reduction option applied.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``.
    """
    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(sigmoid_focal_loss)
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the model's estimated probability of the true class for each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor (1 - p_t)^gamma down-weights easy, well-classified examples.
    loss = ce_loss * ((1 - p_t) ** gamma)

    # alpha < 0 disables class balancing entirely.
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    # Check reduction option and return loss accordingly
    if reduction == "none":
        pass
    elif reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    else:
        # Fixed: the original f-string was missing the closing quote after
        # {reduction}, producing an unbalanced-quote error message.
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
        )
    return loss