Commit 2db01991 authored by mohammad
Browse files

fixed the iteration count

parent 9787e6c9
......@@ -310,7 +310,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
got_nan_key = 'got nan'
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
......@@ -319,7 +320,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
total_loss_dict[got_nan_key] = is_nan
got_nan = got_nan or is_nan
total_loss_dict[got_nan_key] = total_loss_dict.get(
got_nan_key, 0) + int(got_nan)
# Logging.
timers_to_log = []
......@@ -373,6 +377,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[got_nan_key])
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[got_nan_key] = 0
print_rank_0(log_string)
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment