loss.py 660 Bytes
Newer Older
sunxx1's avatar
sunxx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F


class LabelSmoothLoss(_Loss):
    """Cross-entropy against label-smoothed targets, averaged over the batch.

    Every class receives a base probability of ``smooth_ratio / num_classes``
    and the labelled class additionally receives ``1 - smooth_ratio``. The
    resulting target rows sum to 1 only when ``num_classes`` equals the size
    of the class dimension of the logits passed to :meth:`forward`.

    Args:
        smooth_ratio: Total probability mass spread uniformly over all
            classes.
        num_classes: Number of classes used to compute the per-class share.
    """

    def __init__(self, smooth_ratio, num_classes):
        super().__init__()
        self.smooth_ratio = smooth_ratio
        # Per-class share of the smoothing mass.
        self.v = smooth_ratio / num_classes

    def forward(self, input, label):
        """Return the smoothed cross-entropy loss as a scalar tensor.

        Args:
            input: Logits indexed as ``(sample, class)``; softmax is taken
                over dimension 1.
            label: Class index per sample; cast to ``torch.long`` and
                flattened into a column before use.

        Returns:
            The summed loss divided by ``input.size(0)``. The ``reduction``
            setting inherited from ``_Loss`` is not consulted.
        """
        class_index = label.to(torch.long).view(-1, 1)
        on_value = 1 - self.smooth_ratio + self.v

        # Start from the uniform share everywhere, then overwrite the entry
        # of the labelled class in each row.
        soft_target = torch.full_like(input, self.v)
        soft_target.scatter_(1, class_index, on_value)

        log_prob = F.log_softmax(input, 1)
        batch_size = input.size(0)
        # The target is a constant w.r.t. autograd; detach() keeps that explicit.
        return -(log_prob * soft_target.detach()).sum() / batch_size