Commit 6f6cb4ab authored by Myle Ott's avatar Myle Ott
Browse files

Add reduce kwarg to criterions

parent dcbf5e75
...@@ -17,7 +17,7 @@ class CrossEntropyCriterion(FairseqCriterion): ...@@ -17,7 +17,7 @@ class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, dst_dict): def __init__(self, args, dst_dict):
super().__init__(args, dst_dict) super().__init__(args, dst_dict)
def forward(self, model, sample): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
Returns a tuple with three elements: Returns a tuple with three elements:
...@@ -28,10 +28,11 @@ class CrossEntropyCriterion(FairseqCriterion): ...@@ -28,10 +28,11 @@ class CrossEntropyCriterion(FairseqCriterion):
net_output = model(**sample['net_input']) net_output = model(**sample['net_input'])
input = net_output.view(-1, net_output.size(-1)) input = net_output.view(-1, net_output.size(-1))
target = sample['target'].view(-1) target = sample['target'].view(-1)
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx) loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = { logging_output = {
'loss': loss.data[0], 'loss': loss.data[0] if reduce else loss.data,
'sample_size': sample_size, 'sample_size': sample_size,
} }
return loss, sample_size, logging_output return loss, sample_size, logging_output
......
...@@ -16,7 +16,7 @@ class FairseqCriterion(_Loss): ...@@ -16,7 +16,7 @@ class FairseqCriterion(_Loss):
self.args = args self.args = args
self.padding_idx = dst_dict.pad() self.padding_idx = dst_dict.pad()
def forward(self, model, sample): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
Returns a tuple with three elements: Returns a tuple with three elements:
......
...@@ -17,7 +17,7 @@ from .fairseq_criterion import FairseqCriterion ...@@ -17,7 +17,7 @@ from .fairseq_criterion import FairseqCriterion
class LabelSmoothedNLLLoss(torch.autograd.Function): class LabelSmoothedNLLLoss(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input, target, eps, padding_idx, weights): def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
grad_input = input.new(input.size()).zero_() grad_input = input.new(input.size()).zero_()
target = target.view(target.size(0), 1) target = target.view(target.size(0), 1)
grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1) grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
...@@ -34,11 +34,14 @@ class LabelSmoothedNLLLoss(torch.autograd.Function): ...@@ -34,11 +34,14 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
grad_input = grad_input.add(-eps / norm) grad_input = grad_input.add(-eps / norm)
ctx.grad_input = grad_input ctx.grad_input = grad_input
if reduce:
return input.new([grad_input.view(-1).dot(input.view(-1))]) return input.new([grad_input.view(-1).dot(input.view(-1))])
else:
return grad_input * input
@staticmethod @staticmethod
def backward(ctx, grad): def backward(ctx, grad):
return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None, None
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
...@@ -48,7 +51,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -48,7 +51,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
self.eps = args.label_smoothing self.eps = args.label_smoothing
self.weights = weights self.weights = weights
def forward(self, model, sample): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
Returns a tuple with three elements: Returns a tuple with three elements:
...@@ -59,10 +62,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -59,10 +62,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
net_output = model(**sample['net_input']) net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1)), dim=1) input = F.log_softmax(net_output.view(-1, net_output.size(-1)), dim=1)
target = sample['target'].view(-1) target = sample['target'].view(-1)
loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights) loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights, reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = { logging_output = {
'loss': loss.data[0], 'loss': loss.data[0] if reduce else loss.data,
'sample_size': sample_size, 'sample_size': sample_size,
} }
return loss, sample_size, logging_output return loss, sample_size, logging_output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment