".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "c0815850b202cfd2c713d99a832917bac9d6b9d3"
Unverified Commit 6f0748d6 authored by Christian Sarofeen's avatar Christian Sarofeen Committed by GitHub
Browse files

Merge pull request #22 from romerojosh/master

Fixes to validation in imagenet example scripts.
parents 1d2094a1 cf45c54c
...@@ -377,13 +377,15 @@ def validate(val_loader, model, criterion): ...@@ -377,13 +377,15 @@ def validate(val_loader, model, criterion):
output = model(input_var) output = model(input_var)
loss = criterion(output, target_var) loss = criterion(output, target_var)
reduced_loss = reduce_tensor(loss.data)
# measure accuracy and record loss # measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
reduced_prec1 = reduce_tensor(prec1) if args.distributed:
reduced_prec5 = reduce_tensor(prec5) reduced_loss = reduce_tensor(loss.data)
prec1 = reduce_tensor(prec1)
prec5 = reduce_tensor(prec5)
else:
reduced_loss = loss.data
losses.update(to_python_float(reduced_loss), input.size(0)) losses.update(to_python_float(reduced_loss), input.size(0))
top1.update(to_python_float(prec1), input.size(0)) top1.update(to_python_float(prec1), input.size(0))
......
...@@ -338,13 +338,15 @@ def validate(val_loader, model, criterion): ...@@ -338,13 +338,15 @@ def validate(val_loader, model, criterion):
output = model(input_var) output = model(input_var)
loss = criterion(output, target_var) loss = criterion(output, target_var)
reduced_loss = reduce_tensor(loss.data)
# measure accuracy and record loss # measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
reduced_prec1 = reduce_tensor(prec1) if args.distributed:
reduced_prec5 = reduce_tensor(prec5) reduced_loss = reduce_tensor(loss.data)
prec1 = reduce_tensor(prec1)
prec5 = reduce_tensor(prec5)
else:
reduced_loss = loss.data
losses.update(to_python_float(reduced_loss), input.size(0)) losses.update(to_python_float(reduced_loss), input.size(0))
top1.update(to_python_float(prec1), input.size(0)) top1.update(to_python_float(prec1), input.size(0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment