Commit b4b0d739 authored by Deepak Narayanan

Merge branch 'fix_logging' into 'main'

fixed loss average when all but one value is skipped

See merge request ADLR/megatron-lm!164
parents 79888e16 664cd28b
@@ -315,7 +315,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     got_nan = False
     for key in loss_dict:
         if not skipped_iter:
-            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
+            total_loss_dict[key] = total_loss_dict.get(
+                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
         else:
             value = loss_dict[key].float().sum().item()
             is_nan = value == float('inf') or \
@@ -369,8 +370,9 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     for key in total_loss_dict:
         if key not in [skipped_iters_key, got_nan_key]:
             avg = total_loss_dict[key].item() / float(num_iterations)
-            log_string += ' {}: {:.6E} |'.format(key, avg)
-            total_loss_dict[key] = 0.0
+            if avg > 0.0:
+                log_string += ' {}: {:.6E} |'.format(key, avg)
+            total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
     if args.fp16:
         log_string += ' loss scale: {:.1f} |'.format(loss_scale)
     log_string += ' number of skipped iterations: {:3d} |'.format(
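The gist of the change: per-key loss totals are kept as tensors rather than Python floats, so the reset value and the accumulated value always have the same type, and an average is only appended to the log string when the key actually accumulated a nonzero value over the interval. Below is a minimal standalone sketch of that accumulate-and-log pattern; plain CPU tensors stand in for torch.cuda.FloatTensor, and the helper name, the num_iterations computation, and the toy inputs are illustrative assumptions, not taken from the repo.

import torch

def log_interval_averages(per_iter_losses, skipped_flags):
    """Sketch of the patched accumulate-and-log logic in training_log."""
    total_loss_dict = {}
    for loss_dict, skipped_iter in zip(per_iter_losses, skipped_flags):
        for key in loss_dict:
            if not skipped_iter:
                # Accumulate as a tensor; the default is a tensor too, so
                # .item() below works even when every iteration skipped.
                total_loss_dict[key] = total_loss_dict.get(
                    key, torch.tensor([0.0])) + loss_dict[key]
    # Illustrative denominator: count of non-skipped iterations, floored at 1.
    num_iterations = max(1, sum(1 for s in skipped_flags if not s))
    log_string = ''
    for key in total_loss_dict:
        avg = total_loss_dict[key].item() / float(num_iterations)
        if avg > 0.0:  # only log keys that actually accumulated a value
            log_string += ' {}: {:.6E} |'.format(key, avg)
        total_loss_dict[key] = torch.tensor([0.0])  # tensor reset, not 0.0
    return log_string

# Two of three iterations skipped: the average reflects only the real value.
print(log_interval_averages(
    [{'lm loss': torch.tensor([4.2])},
     {'lm loss': torch.tensor([float('inf')])},
     {'lm loss': torch.tensor([float('inf')])}],
    skipped_flags=[False, True, True]))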