Commit cf45c54c authored by Josh Romero's avatar Josh Romero
Browse files

Fixes to validation in imagenet example scripts. Precision and loss reporting...

Fixes to validation in imagenet example scripts. Precision and loss reporting modified to be consistent with train.
parent 21c229e0
...@@ -377,13 +377,15 @@ def validate(val_loader, model, criterion): ...@@ -377,13 +377,15 @@ def validate(val_loader, model, criterion):
output = model(input_var) output = model(input_var)
loss = criterion(output, target_var) loss = criterion(output, target_var)
reduced_loss = reduce_tensor(loss.data)
# measure accuracy and record loss # measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
reduced_prec1 = reduce_tensor(prec1) if args.distributed:
reduced_prec5 = reduce_tensor(prec5) reduced_loss = reduce_tensor(loss.data)
prec1 = reduce_tensor(prec1)
prec5 = reduce_tensor(prec5)
else:
reduced_loss = loss.data
losses.update(to_python_float(reduced_loss), input.size(0)) losses.update(to_python_float(reduced_loss), input.size(0))
top1.update(to_python_float(prec1), input.size(0)) top1.update(to_python_float(prec1), input.size(0))
......
...@@ -338,13 +338,15 @@ def validate(val_loader, model, criterion): ...@@ -338,13 +338,15 @@ def validate(val_loader, model, criterion):
output = model(input_var) output = model(input_var)
loss = criterion(output, target_var) loss = criterion(output, target_var)
reduced_loss = reduce_tensor(loss.data)
# measure accuracy and record loss # measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
reduced_prec1 = reduce_tensor(prec1) if args.distributed:
reduced_prec5 = reduce_tensor(prec5) reduced_loss = reduce_tensor(loss.data)
prec1 = reduce_tensor(prec1)
prec5 = reduce_tensor(prec5)
else:
reduced_loss = loss.data
losses.update(to_python_float(reduced_loss), input.size(0)) losses.update(to_python_float(reduced_loss), input.size(0))
top1.update(to_python_float(prec1), input.size(0)) top1.update(to_python_float(prec1), input.size(0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment