Commit e7094b14 authored by Myle Ott's avatar Myle Ott Committed by Sergey Edunov
Browse files

Fix LabelSmoothedCrossEntropy test

parent 78a6ef02
......@@ -7,6 +7,7 @@
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from fairseq import utils
......@@ -24,6 +25,8 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
norm = grad_input.size(-1)
if weights is not None:
if isinstance(grad_input, Variable) and not isinstance(weights, Variable):
weights = Variable(weights, requires_grad=False)
norm = weights.sum()
grad_input.mul(weights.view(1, weights.size(0)).expand_as(grad_input))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment