# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class LabelSmoothing(nn.Module):
    """
    NLL loss with label smoothing.

    The per-element loss is

        (1 - smoothing) * NLL(target) + smoothing * mean_c(-log p_c)

    i.e. cross-entropy against a target distribution that puts
    ``1 - smoothing`` on the true class plus ``smoothing`` spread uniformly
    over all classes.
    """

    def __init__(self, smoothing=0.0):
        """
        Constructor for the LabelSmoothing module.

        :param smoothing: label smoothing factor; 0.0 gives plain NLL loss
            on the log-softmax of the input.
        """
        super().__init__()
        # Weight given to the NLL term of the true class.
        self.confidence = 1.0 - smoothing
        # Weight given to the uniform (mean over classes) term.
        self.smoothing = smoothing

    def forward(self, x, target):
        """
        Compute the label-smoothed loss, averaged over all elements.

        :param x: unnormalized logits; the class dimension is the LAST
            dimension, so shapes like ``(N, C)`` or ``(B, T, C)`` work.
        :param target: integer class indices with the shape of ``x`` minus
            its last dimension (e.g. ``(N,)`` or ``(B, T)``).
        :returns: scalar tensor, the mean loss over all target elements.
        """
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        # Index along the last (class) dimension. Using -1 for both
        # unsqueeze and squeeze keeps the gather index aligned with
        # ``dim=-1`` for inputs of any rank, not just 2-D.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)
        # Uniform-distribution term: average negative log-prob over classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()