Commit 9ccebe5b authored by ptrblck's avatar ptrblck
Browse files

call .float() on GPU, remove unnecessary push to GPU

parent 28bdc04e
@@ -175,7 +175,7 @@ def test():
         output = model(data)
         test_loss += to_python_float(F.nll_loss(output, target, size_average=False).data) # sum up batch loss
         pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
-        correct += pred.eq(target.data.view_as(pred)).cpu().float().sum()
+        correct += pred.eq(target.data.view_as(pred)).float().cpu().sum()
     test_loss /= len(test_loader.dataset)
...
@@ -371,8 +371,6 @@ def validate(val_loader, model, criterion):
     while input is not None:
         i += 1
-        target = target.cuda(async=True)
         # compute output
         with torch.no_grad():
             output = model(input)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment