# focal_loss.py
import torch.nn as nn
import torch.nn.functional as F

from mmdet.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from .utils import weight_reduce_loss
from ..registry import LOSSES


# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """Pure-PyTorch reference implementation of the sigmoid focal loss.

    Kept for debugging; ``sigmoid_focal_loss`` is the variant backed by the
    compiled ``mmdet.ops`` kernel.

    Args:
        pred (Tensor): Raw logits (sigmoid is applied internally).
        target (Tensor): Binary targets with the same shape as ``pred``
            (required by ``binary_cross_entropy_with_logits``); cast to
            ``pred``'s dtype before use.
        weight (Tensor | None): Loss weight. A 1-D weight whose length
            equals the first dimension of a 2-D loss is reshaped to
            ``(N, 1)`` so it scales whole rows, matching
            ``sigmoid_focal_loss``. Any other weight is passed through
            unchanged.
        gamma (float): Focusing exponent applied to ``1 - p_t``.
        alpha (float): Factor applied to positive targets; negatives get
            ``1 - alpha``.
        reduction (str): Forwarded to ``weight_reduce_loss``.
        avg_factor (float | None): Forwarded to ``weight_reduce_loss``.

    Returns:
        Tensor: The focal loss after ``weight_reduce_loss``.
    """
    import torch

    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # One minus the probability assigned to each element's true label:
    # (1 - p) where target == 1 and p where target == 0. Raised to ``gamma``
    # this is the focal modulating factor (1 - p_t) ** gamma.
    one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * one_minus_pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if (torch.is_tensor(weight) and weight.dim() == 1 and loss.dim() == 2
            and weight.size(0) == loss.size(0)):
        # A per-sample weight of shape (N,) must broadcast across the second
        # axis of the (N, C) loss, mirroring sigmoid_focal_loss; without the
        # reshape it would broadcast along the wrong axis (or fail).
        weight = weight.view(-1, 1)
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    """Sigmoid focal loss computed by the compiled ``mmdet.ops`` kernel.

    Args:
        pred (Tensor): Raw logits passed to the kernel.
        target (Tensor): Targets in the format the kernel expects
            (NOTE(review): presumably class labels rather than one-hot —
            confirm against ``mmdet.ops.sigmoid_focal_loss``).
        weight (Tensor | None): Loss weight. A weight that already has the
            loss's shape is used as-is; any other weight is reshaped to
            ``(-1, 1)`` so a per-sample weight scales each row of the loss.
        gamma (float): Focusing parameter forwarded to the kernel.
        alpha (float): Balancing factor forwarded to the kernel.
        reduction (str): Forwarded to ``weight_reduce_loss``.
        avg_factor (float | None): Forwarded to ``weight_reduce_loss``.

    Returns:
        Tensor: The focal loss after ``weight_reduce_loss``.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable; arguments are passed positionally.
    loss = _sigmoid_focal_loss(pred, target, gamma, alpha)
    if weight is not None and weight.shape != loss.shape:
        # Presumably a per-sample weight of shape (N,) against an (N, C)
        # loss: make it a column so it broadcasts across the second axis.
        # A weight that already matches the loss is not touched — flattening
        # it to (N * C, 1) would break broadcasting.
        weight = weight.view(-1, 1)
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


@LOSSES.register_module
class FocalLoss(nn.Module):
    """Focal loss module, registered in ``LOSSES``.

    Only the sigmoid variant is implemented; the computation is delegated to
    ``sigmoid_focal_loss``.

    Args:
        use_sigmoid (bool): Must be True; other variants are not supported.
        gamma (float): Focusing parameter.
        alpha (float): Balancing factor.
        reduction (str): Default reduction mode, overridable per call via
            ``reduction_override``.
        loss_weight (float): Scalar multiplier applied to the loss.
    """

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute ``loss_weight`` times the sigmoid focal loss.

        Args:
            pred (Tensor): Raw logits.
            target (Tensor): Targets forwarded to ``sigmoid_focal_loss``.
            weight (Tensor | None): Loss weight forwarded unchanged.
            avg_factor (float | None): Forwarded to ``sigmoid_focal_loss``.
            reduction_override (str | None): One of None, 'none', 'mean',
                'sum'; when truthy it replaces ``self.reduction`` for this
                call.

        Returns:
            Tensor: The scaled focal loss.

        Raises:
            NotImplementedError: If ``use_sigmoid`` is False (only reachable
                when the constructor's assertion was stripped or the
                attribute was changed after construction).
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if not self.use_sigmoid:
            raise NotImplementedError
        return self.loss_weight * sigmoid_focal_loss(
            pred,
            target,
            weight,
            gamma=self.gamma,
            alpha=self.alpha,
            reduction=reduction,
            avg_factor=avg_factor)