Commit 6f6cb4ab authored by Myle Ott

Add reduce kwarg to criterions

parent dcbf5e75
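
This commit threads a `reduce` kwarg through each criterion: with `reduce=True` (the default, preserving the old behavior) a criterion returns a single summed loss, while `reduce=False` keeps one loss term per element, which is what per-token or per-sentence scoring needs. A hedged sketch of the distinction, written against the PyTorch 0.3-era API this diff targets:

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    logits = Variable(torch.randn(4, 10))              # 4 tokens, 10-word vocab
    target = Variable(torch.LongTensor([1, 2, 3, 0]))

    summed = F.cross_entropy(logits, target, size_average=False, reduce=True)
    per_tok = F.cross_entropy(logits, target, size_average=False, reduce=False)

    # reduce=True simply sums the per-token losses into a one-element result:
    assert abs(per_tok.sum().data[0] - summed.data[0]) < 1e-4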
@@ -17,7 +17,7 @@ class CrossEntropyCriterion(FairseqCriterion):
     def __init__(self, args, dst_dict):
         super().__init__(args, dst_dict)
 
-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
 
         Returns a tuple with three elements:
@@ -28,10 +28,11 @@ class CrossEntropyCriterion(FairseqCriterion):
         net_output = model(**sample['net_input'])
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
-        loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
+        loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx,
+                               reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0],
+            'loss': loss.data[0] if reduce else loss.data,
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
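
For callers, the kwarg exposes per-token losses without touching the training path. A hypothetical caller, not introduced by this commit (`criterion`, `model`, and `sample` are assumed to be set up as usual in fairseq):

    loss, sample_size, logging_output = criterion(model, sample, reduce=False)
    # `loss` now holds one entry per flattened target token rather than a
    # single summed scalar; per-sentence totals fall out of a reshape:
    per_token = loss.view(sample['target'].size())
    per_sentence = per_token.sum(dim=1)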
@@ -16,7 +16,7 @@ class FairseqCriterion(_Loss):
         self.args = args
         self.padding_idx = dst_dict.pad()
 
-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
 
         Returns a tuple with three elements:
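
Since FairseqCriterion is the shared base class, this hunk makes the kwarg a contract: every criterion's forward should accept `reduce` (the two concrete criterions in this commit are updated accordingly). A subclass left with the old two-argument signature would break as soon as a caller passes the kwarg; a hedged illustration with a hypothetical subclass:

    class StaleCriterion(FairseqCriterion):

        def forward(self, model, sample):  # predates this commit
            ...

    # StaleCriterion(args, dst_dict)(model, sample, reduce=False)
    # -> TypeError: forward() got an unexpected keyword argument 'reduce'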
@@ -17,7 +17,7 @@ from .fairseq_criterion import FairseqCriterion
 class LabelSmoothedNLLLoss(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, input, target, eps, padding_idx, weights):
+    def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
         grad_input = input.new(input.size()).zero_()
         target = target.view(target.size(0), 1)
         grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
@@ -34,11 +34,14 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
         grad_input = grad_input.add(-eps / norm)
 
         ctx.grad_input = grad_input
-        return input.new([grad_input.view(-1).dot(input.view(-1))])
+        if reduce:
+            return input.new([grad_input.view(-1).dot(input.view(-1))])
+        else:
+            return grad_input * input
 
     @staticmethod
     def backward(ctx, grad):
-        return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None
+        return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None, None
 
 
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
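
The extra trailing None in backward is required, not cosmetic: autograd expects backward to return exactly one gradient per argument of forward, and forward just gained a sixth argument (`reduce`); non-differentiable arguments get None. A toy Function (illustrative names, not from this diff) showing the same rule:

    import torch

    class Scale(torch.autograd.Function):
        """Toy example: y = x * c for a non-tensor constant c."""

        @staticmethod
        def forward(ctx, x, c):
            ctx.c = c
            return x * c

        @staticmethod
        def backward(ctx, grad_output):
            # One return value per forward() argument: a gradient for x,
            # then None for the non-differentiable constant c -- the same
            # rule that adds the sixth None above.
            return grad_output * ctx.c, None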
@@ -48,7 +51,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.eps = args.label_smoothing
         self.weights = weights
 
-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
 
         Returns a tuple with three elements:
@@ -59,10 +62,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         net_output = model(**sample['net_input'])
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)), dim=1)
         target = sample['target'].view(-1)
-        loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights)
+        loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights, reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0],
+            'loss': loss.data[0] if reduce else loss.data,
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
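
A hedged sanity check, not part of the commit: because reduce=True returns the dot product of grad_input and input while reduce=False returns their elementwise product, the unreduced (tokens x vocab) output must sum back to the reduced scalar. This assumes `weights=None` is accepted, as when no class weights are configured:

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable
    # LabelSmoothedNLLLoss imported from the module shown above (path assumed)

    lprobs = F.log_softmax(Variable(torch.randn(8, 5)), dim=1)  # 8 tokens, 5-word vocab
    target = Variable(torch.LongTensor(8).random_(1, 5))        # keep targets off padding_idx=0
    eps, padding_idx, weights = 0.1, 0, None

    reduced = LabelSmoothedNLLLoss.apply(lprobs, target, eps, padding_idx, weights, True)
    unreduced = LabelSmoothedNLLLoss.apply(lprobs, target, eps, padding_idx, weights, False)
    assert abs(unreduced.sum().data[0] - reduced.data[0]) < 1e-4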