Commit 78a6ef02 authored by Myle Ott's avatar Myle Ott Committed by Sergey Edunov
Browse files

pytorch update: no need to rewrap variable in backward()

parent 866b27d5
...@@ -41,7 +41,10 @@ class LabelSmoothedNLLLoss(torch.autograd.Function): ...@@ -41,7 +41,10 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad): def backward(ctx, grad):
return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None grad_input = ctx.grad_input
if not isinstance(grad_input, torch.autograd.Variable):
grad_input = utils.volatile_variable(grad_input)
return grad_input * grad, None, None, None, None, None
@register_criterion('label_smoothed_cross_entropy') @register_criterion('label_smoothed_cross_entropy')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment