Commit 9ccebe5b authored by ptrblck's avatar ptrblck
Browse files

call .float() on GPU, remove unnecessary push to GPU

parent 28bdc04e
......@@ -175,7 +175,7 @@ def test():
output = model(data)
test_loss += to_python_float(F.nll_loss(output, target, size_average=False).data) # sum up batch loss
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).cpu().float().sum()
correct += pred.eq(target.data.view_as(pred)).float().cpu().sum()
test_loss /= len(test_loader.dataset)
......
......@@ -371,8 +371,6 @@ def validate(val_loader, model, criterion):
while input is not None:
i += 1
target = target.cuda(async=True)
# compute output
with torch.no_grad():
output = model(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment