Commit 664cd28b authored by mohammad's avatar mohammad
Browse files

Fix loss averaging when all but one value is skipped

parent 79888e16
......@@ -315,7 +315,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
total_loss_dict[key] = total_loss_dict.get(
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
is_nan = value == float('inf') or \
......@@ -369,8 +370,9 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
for key in total_loss_dict:
if key not in [skipped_iters_key, got_nan_key]:
avg = total_loss_dict[key].item() / float(num_iterations)
if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = 0.0
total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
if args.fp16:
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
log_string += ' number of skipped iterations: {:3d} |'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment