focal_loss.py 1.03 KB
Newer Older
Jiangmiao Pang's avatar
Jiangmiao Pang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch.nn as nn
from mmdet.core import weighted_sigmoid_focal_loss

from ..registry import LOSSES


@LOSSES.register_module
class FocalLoss(nn.Module):
    """Sigmoid focal loss wrapped as an ``nn.Module``.

    The module stores the focal-loss hyper-parameters and delegates the
    actual computation to ``weighted_sigmoid_focal_loss``, multiplying the
    value it returns by ``loss_weight``.

    Args:
        use_sigmoid: Must be exactly ``True``; only the sigmoid variant is
            supported. The default is ``False``, which is rejected, so
            callers have to opt in explicitly.
        loss_weight: Factor the criterion's result is multiplied by.
        gamma: Forwarded to the criterion as its ``gamma`` keyword.
        alpha: Forwarded to the criterion as its ``alpha`` keyword.

    Raises:
        AssertionError: If ``use_sigmoid`` is not ``True``.
    """

    def __init__(self,
                 use_sigmoid=False,
                 loss_weight=1.0,
                 gamma=2.0,
                 alpha=0.25):
        super(FocalLoss, self).__init__()
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``. Exception type and message are the
        # same as the former assertion produced.
        if use_sigmoid is not True:
            raise AssertionError('Only sigmoid focal loss supported now.')
        self.use_sigmoid = use_sigmoid
        self.loss_weight = loss_weight
        self.gamma = gamma
        self.alpha = alpha
        self.cls_criterion = weighted_sigmoid_focal_loss

    def forward(self, cls_score, label, label_weight, *args, **kwargs):
        """Compute the focal loss scaled by ``loss_weight``.

        Args:
            cls_score: First positional argument handed to the criterion.
            label: Second positional argument handed to the criterion.
            label_weight: Third positional argument handed to the criterion.
            *args: Extra positional arguments, passed to the criterion
                right after ``label_weight``.
            **kwargs: Extra keyword arguments passed to the criterion.

        Returns:
            ``self.loss_weight`` multiplied by the criterion's result.

        Raises:
            NotImplementedError: If ``self.use_sigmoid`` is falsy.
        """
        if not self.use_sigmoid:
            raise NotImplementedError

        # ``*args`` is spelled before the keywords because that is the order
        # in which Python binds them anyway; the call is equivalent to
        # listing ``*args`` after ``alpha=...``.
        # NOTE(review): any extra positional argument lands in the
        # criterion's fourth positional slot; if that slot is ``gamma`` the
        # call raises TypeError -- confirm against the helper's signature.
        raw_loss = self.cls_criterion(
            cls_score,
            label,
            label_weight,
            *args,
            gamma=self.gamma,
            alpha=self.alpha,
            **kwargs)
        return self.loss_weight * raw_loss