Commit e1f49695 authored by Myle Ott's avatar Myle Ott
Browse files

Rename LabelSmoothedCrossEntropy to LabelSmoothedNLLLoss

parent b1dfd39e
...@@ -14,7 +14,7 @@ import torch.nn.functional as F ...@@ -14,7 +14,7 @@ import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion from .fairseq_criterion import FairseqCriterion
class LabelSmoothedCrossEntropy(torch.autograd.Function): class LabelSmoothedNLLLoss(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input, target, eps, padding_idx, weights): def forward(ctx, input, target, eps, padding_idx, weights):
...@@ -59,7 +59,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -59,7 +59,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
net_output = model(**sample['net_input']) net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1))) input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
target = sample['target'].view(-1) target = sample['target'].view(-1)
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights) loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = { logging_output = {
'loss': loss.data[0], 'loss': loss.data[0],
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
import torch import torch
import unittest import unittest
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
from torch.autograd import Variable, gradcheck from torch.autograd import Variable, gradcheck
...@@ -21,7 +21,7 @@ class TestLabelSmoothing(unittest.TestCase): ...@@ -21,7 +21,7 @@ class TestLabelSmoothing(unittest.TestCase):
input = Variable(torch.randn(3, 5), requires_grad=True) input = Variable(torch.randn(3, 5), requires_grad=True)
idx = torch.rand(3) * 4 idx = torch.rand(3) * 4
target = Variable(idx.long()) target = Variable(idx.long())
criterion = LabelSmoothedCrossEntropy() criterion = LabelSmoothedNLLLoss()
self.assertTrue(gradcheck( self.assertTrue(gradcheck(
lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target) lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
)) ))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment