Commit 4bf923d5 authored by Jared Casper

Merge branch 'log_report_fix' into 'main'

Fixed a logging bug: skipped and NaN iterations are now counted inside training_log, and per-interval loss averages exclude skipped iterations.

See merge request ADLR/megatron-lm!134
parents 2bd27e20 2db01991
@@ -93,9 +93,9 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     iteration = 0
     if args.do_train and args.train_iters > 0:
-        iteration, _ = train(forward_step_func,
-                             model, optimizer, lr_scheduler,
-                             train_data_iterator, valid_data_iterator)
+        iteration = train(forward_step_func,
+                          model, optimizer, lr_scheduler,
+                          train_data_iterator, valid_data_iterator)

     if args.do_valid:
         prefix = 'the end of training for val data'
@@ -299,15 +299,31 @@ def train_step(forward_step_func, data_iterator,

 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag):
+                 loss_scale, report_memory_flag, skipped_iter):
     """Log training information such as losses, timing, ...."""
     args = get_args()
     timers = get_timers()
     writer = get_tensorboard_writer()

     # Update losses.
+    skipped_iters_key = 'skipped iterations'
+    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
+        skipped_iters_key, 0) + skipped_iter
+    got_nan_key = 'got nan'
+
+    got_nan = False
     for key in loss_dict:
-        total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
+        if not skipped_iter:
+            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
+        else:
+            value = loss_dict[key].float().sum().item()
+            is_nan = value == float('inf') or \
+                     value == -float('inf') or \
+                     value != value
+            got_nan = got_nan or is_nan
+
+    total_loss_dict[got_nan_key] = total_loss_dict.get(
+        got_nan_key, 0) + int(got_nan)

     # Logging.
     timers_to_log = []
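For context on the check added above: it relies on the IEEE-754 rule that NaN is the only value that compares unequal to itself, with the two explicit comparisons catching overflow to infinity. A standalone sketch of the same test (illustrative only, not part of the patch):

```python
import math

def is_nonfinite(value: float) -> bool:
    # NaN is the only float for which value != value holds;
    # the explicit comparisons catch +inf and -inf.
    return value == float('inf') or value == -float('inf') or value != value

assert is_nonfinite(float('nan'))
assert is_nonfinite(float('-inf'))
assert not is_nonfinite(1.0)
# The stdlib expresses the same predicate as: not math.isfinite(value)
assert is_nonfinite(2.5) == (not math.isfinite(2.5))
```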
@@ -347,12 +363,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
             elapsed_time * 1000.0 / args.log_interval)
         log_string += ' learning rate: {:.3E} |'.format(learning_rate)
+        num_iterations = max(
+            1, args.log_interval - total_loss_dict[skipped_iters_key])
         for key in total_loss_dict:
-            avg = total_loss_dict[key].item() / args.log_interval
-            log_string += ' {}: {:.6E} |'.format(key, avg)
-            total_loss_dict[key] = 0.0
+            if key not in [skipped_iters_key, got_nan_key]:
+                avg = total_loss_dict[key].item() / float(num_iterations)
+                log_string += ' {}: {:.6E} |'.format(key, avg)
+                total_loss_dict[key] = 0.0
         if args.fp16:
             log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+        log_string += ' number of skipped iterations: {:3d} |'.format(
+            total_loss_dict[skipped_iters_key])
+        log_string += ' number of nan iterations: {:3d} |'.format(
+            total_loss_dict[got_nan_key])
+        total_loss_dict[skipped_iters_key] = 0
+        total_loss_dict[got_nan_key] = 0
         print_rank_0(log_string)
         if report_memory_flag:
             report_memory('after {} iterations'.format(iteration))
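To make the new averaging concrete: loss sums are now divided by the number of iterations that actually contributed to them, not by the full logging window. A small worked example with made-up numbers (the interval, skipped count, and losses are illustrative, not from any real run):

```python
log_interval = 100          # hypothetical logging window
skipped = 3                 # hypothetical skipped iterations in the window
num_iterations = max(1, log_interval - skipped)  # max() guards a fully skipped window

loss_sum = 97 * 2.5         # 97 contributing steps, each with loss 2.5

print(loss_sum / num_iterations)  # 2.5   -> the fixed, unbiased average
print(loss_sum / log_interval)    # 2.425 -> the old average, biased low
```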
@@ -376,7 +401,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
-    skipped_iters = 0

     timers('interval time').start()
     report_memory_flag = True
@@ -386,7 +410,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                             model,
                                             optimizer,
                                             lr_scheduler)
-        skipped_iters += skipped_iter
         iteration += 1

         # Logging.
@@ -396,7 +419,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
-                                          report_memory_flag)
+                                          report_memory_flag, skipped_iter)

         # Autoresume
         if args.adlr_autoresume and \
@@ -425,7 +448,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                    'iteration {}'.format(rank, time_str, iteration))
             sys.exit()

-    return iteration, skipped_iters
+    return iteration
def evaluate(forward_step_func, data_iterator, model, verbose=False):
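One compatibility note on the simplified return value: train now returns a single int, so any out-of-tree caller still unpacking two values would fail at runtime. A minimal illustration (the caller below is hypothetical, not from this repository):

```python
def train():
    return 5  # stands in for the new single-value return

iteration = train()                # new-style call site: fine
try:
    iteration, skipped = train()   # old-style unpacking now breaks
except TypeError as err:
    print(err)  # cannot unpack non-iterable int object
```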