import torch
import torch.nn as nn
import torch.nn.functional as F

# The code below is based on
# https://zhuanlan.zhihu.com/p/28527749


class FocalLoss(nn.Module):
    r"""
    This criterion is an implementation of Focal Loss, proposed in
    "Focal Loss for Dense Object Detection" (Lin et al., 2017).

        Loss(x, class) = -\alpha (1 - softmax(x)[class])^\gamma \log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        class_num (int): number of classes.
        alpha (Tensor): per-class weighting factor of shape (class_num, 1).
                        Defaults to a vector of ones (no re-weighting).
        gamma (float): gamma > 0; reduces the relative loss for well-classified
                       examples (p > .5), putting more focus on hard,
                       misclassified examples.
        size_average (bool): By default, the losses are averaged over
                             observations for each minibatch. However, if
                             size_average is set to False, the losses are
                             instead summed for each minibatch.
    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            # Default: uniform class weights (no re-weighting).
            self.alpha = torch.ones(class_num, 1)
        else:
            self.alpha = alpha if isinstance(alpha, torch.Tensor) else torch.as_tensor(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        # Per-class probabilities, shape (N, C).
        P = F.softmax(inputs, dim=1)

        # One-hot mask of each sample's target class, shape (N, C).
        class_mask = inputs.new_zeros(N, C)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)

        # Look up the alpha weight for each sample's target class.
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        alpha = self.alpha[ids.view(-1)]

        # Probability assigned to the true class, shape (N, 1).
        probs = (P * class_mask).sum(1).view(-1, 1)

        log_p = probs.log()

        # Focal term (1 - p)^gamma down-weights well-classified examples.
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
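

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not from the original file): the random
# logits and targets below are assumed just to exercise the module. With
# gamma=0 and uniform alpha, the focal loss reduces to plain cross-entropy,
# which gives a cheap sanity check against F.cross_entropy.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, C = 8, 5  # arbitrary batch size and class count
    logits = torch.randn(N, C)
    targets = torch.randint(0, C, (N,))

    criterion = FocalLoss(class_num=C, gamma=2)
    print("focal loss:", criterion(logits, targets).item())

    # gamma=0 makes the modulating factor (1 - p)^gamma equal to 1, so the
    # loss collapses to standard mean cross-entropy.
    plain = FocalLoss(class_num=C, gamma=0)
    ce = F.cross_entropy(logits, targets)
    assert torch.allclose(plain(logits, targets), ce, atol=1e-6)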