Commit e1f49695 authored by Myle Ott's avatar Myle Ott
Browse files

Rename LabelSmoothedCrossEntropy to LabelSmoothedNLLLoss

parent b1dfd39e
......@@ -14,7 +14,7 @@ import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion
class LabelSmoothedCrossEntropy(torch.autograd.Function):
class LabelSmoothedNLLLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, input, target, eps, padding_idx, weights):
......@@ -59,7 +59,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
target = sample['target'].view(-1)
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': loss.data[0],
......
......@@ -8,7 +8,7 @@
import torch
import unittest
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
from torch.autograd import Variable, gradcheck
......@@ -21,7 +21,7 @@ class TestLabelSmoothing(unittest.TestCase):
input = Variable(torch.randn(3, 5), requires_grad=True)
idx = torch.rand(3) * 4
target = Variable(idx.long())
criterion = LabelSmoothedCrossEntropy()
criterion = LabelSmoothedNLLLoss()
self.assertTrue(gradcheck(
lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment