import torch
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F


class LabelSmoothLoss(_Loss):
    """Cross-entropy against label-smoothed targets.

    Each row's target puts ``1 - smooth_ratio + smooth_ratio / num_classes``
    on the labelled class and ``smooth_ratio / num_classes`` on every other
    class. The loss is the negative sum of ``log_softmax(input, 1) * target``
    divided by ``input.size(0)``.

    Args:
        smooth_ratio: Probability mass spread uniformly over all classes;
            must lie in ``[0, 1]``. ``0`` gives plain cross-entropy.
        num_classes: Number of classes; must be positive and must equal
            ``input.size(1)`` when the loss is called.

    Raises:
        ValueError: If ``smooth_ratio`` is outside ``[0, 1]`` or
            ``num_classes`` is not positive.
    """

    def __init__(self, smooth_ratio, num_classes):
        super(LabelSmoothLoss, self).__init__()
        if num_classes <= 0:
            raise ValueError(
                "num_classes must be positive, got {!r}".format(num_classes)
            )
        if not 0 <= smooth_ratio <= 1:
            raise ValueError(
                "smooth_ratio must be in [0, 1], got {!r}".format(smooth_ratio)
            )
        self.smooth_ratio = smooth_ratio
        self.num_classes = num_classes
        # Weight assigned to every class before the labelled one is boosted.
        self.v = self.smooth_ratio / num_classes

    def forward(self, input, label):
        """Return the smoothed cross-entropy as a scalar tensor.

        Args:
            input: Unnormalised scores; softmax is taken over dimension 1 and
                the summed loss is divided by ``input.size(0)``.
            label: Class indices, one per row of ``input``. Any shape and
                numeric dtype is accepted: it is moved to ``input``'s device,
                cast to ``long`` and reshaped to ``(-1, 1)``.

        Raises:
            ValueError: If ``input.size(1)`` differs from ``num_classes``;
                ``self.v`` would then not describe a valid distribution.
        """
        if input.size(1) != self.num_classes:
            raise ValueError(
                "input has {} classes along dim 1 but the loss was built "
                "for num_classes={}".format(input.size(1), self.num_classes)
            )

        # full_like never joins the autograd graph, so no detach is needed.
        target = torch.full_like(input, self.v)
        # scatter_ needs int64 indices on the same device as the target.
        index = label.to(device=input.device, dtype=torch.long).reshape(-1, 1)
        target.scatter_(1, index, 1 - self.smooth_ratio + self.v)

        log_probs = F.log_softmax(input, 1)
        return -torch.sum(log_probs * target) / input.size(0)