Commit abd7636e authored by mohammad
Browse files

nan values report added

parent a74f8d9c
...@@ -309,9 +309,17 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -309,9 +309,17 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
skipped_iters_key = 'skipped iterations' skipped_iters_key = 'skipped iterations'
total_loss_dict[skipped_iters_key] = total_loss_dict.get( total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter skipped_iters_key, 0) + skipped_iter
got_nan_key = 'got nan'
for key in loss_dict: for key in loss_dict:
if not skipped_iter: if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key] total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
total_loss_dict[got_nan_key] = is_nan
# Logging. # Logging.
timers_to_log = [] timers_to_log = []
...@@ -354,7 +362,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -354,7 +362,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
num_iterations = max( num_iterations = max(
1, args.log_interval - total_loss_dict[skipped_iters_key]) 1, args.log_interval - total_loss_dict[skipped_iters_key])
for key in total_loss_dict: for key in total_loss_dict:
if key != skipped_iters_key: if key not in [skipped_iters_key, got_nan_key]:
avg = total_loss_dict[key].item() / float(num_iterations) avg = total_loss_dict[key].item() / float(num_iterations)
log_string += ' {}: {:.6E} |'.format(key, avg) log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = 0.0 total_loss_dict[key] = 0.0
...@@ -362,6 +370,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -362,6 +370,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' loss scale: {:.1f} |'.format(loss_scale) log_string += ' loss scale: {:.1f} |'.format(loss_scale)
log_string += ' number of skipped iterations: {:3d} |'.format( log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key]) total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[got_nan_key])
total_loss_dict[skipped_iters_key] = 0 total_loss_dict[skipped_iters_key] = 0
print_rank_0(log_string) print_rank_0(log_string)
if report_memory_flag: if report_memory_flag:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment