Commit ffe53d6f authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Create standalone label_smoothed_nll_loss

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/739

Differential Revision: D16377798

Pulled By: myleott

fbshipit-source-id: 20047c80de2e6f108269ace4ae3eec906a5920dd
parent c811e0e0
...@@ -12,6 +12,26 @@ from fairseq import utils
from . import FairseqCriterion, register_criterion
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    """Compute the label-smoothed negative log-likelihood loss.

    Args:
        lprobs: log-probabilities, with classes along the last dimension.
        target: gold class indices; may have one dimension fewer than *lprobs*.
        epsilon: smoothing weight; a fraction ``epsilon / num_classes`` of the
            loss mass is spread uniformly over all classes.
        ignore_index: if given, positions whose target equals this index are
            dropped from the loss (note: with ``reduce=False`` this flattens
            the per-token losses via boolean indexing).
        reduce: if True, sum the per-token losses to scalars.

    Returns:
        Tuple ``(loss, nll_loss)`` — the smoothed loss and the unsmoothed NLL.
    """
    # Give target a trailing singleton dim so it lines up with lprobs for gather.
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    per_token_nll = -lprobs.gather(dim=-1, index=target)
    # Uniform-smoothing term: negative sum of log-probs over the class dim.
    per_token_smooth = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is None:
        per_token_nll = per_token_nll.squeeze(-1)
        per_token_smooth = per_token_smooth.squeeze(-1)
    else:
        # Keep only non-padding positions (boolean indexing flattens).
        keep = target.ne(ignore_index)
        per_token_nll = per_token_nll[keep]
        per_token_smooth = per_token_smooth[keep]
    if reduce:
        per_token_nll = per_token_nll.sum()
        per_token_smooth = per_token_smooth.sum()
    smoothing_per_class = epsilon / lprobs.size(-1)
    loss = (1. - epsilon) * per_token_nll + smoothing_per_class * per_token_smooth
    return loss, per_token_nll
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
...@@ -51,17 +71,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -51,17 +71,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1, 1)
non_pad_mask = target.ne(self.padding_idx) loss, nll_loss = label_smoothed_nll_loss(
if reduce: lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
nll_loss = -lprobs.gather(dim=-1, index=target).masked_fill_(1.0-non_pad_mask, 0.0) )
nll_loss = nll_loss.sum()
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).masked_fill_(1.0-non_pad_mask, 0.0)
smooth_loss = smooth_loss.sum()
else:
nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
eps_i = self.eps / lprobs.size(-1)
loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
        return loss, nll_loss
    @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment