Commit 4bf923d5 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'log_report_fix' into 'main'

fixed logging bug

See merge request ADLR/megatron-lm!134
parents 2bd27e20 2db01991
...@@ -93,7 +93,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider, ...@@ -93,7 +93,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
iteration = 0 iteration = 0
if args.do_train and args.train_iters > 0: if args.do_train and args.train_iters > 0:
iteration, _ = train(forward_step_func, iteration = train(forward_step_func,
model, optimizer, lr_scheduler, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator) train_data_iterator, valid_data_iterator)
...@@ -299,15 +299,31 @@ def train_step(forward_step_func, data_iterator, ...@@ -299,15 +299,31 @@ def train_step(forward_step_func, data_iterator,
def training_log(loss_dict, total_loss_dict, learning_rate, iteration, def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag): loss_scale, report_memory_flag, skipped_iter):
"""Log training information such as losses, timing, ....""" """Log training information such as losses, timing, ...."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
writer = get_tensorboard_writer() writer = get_tensorboard_writer()
# Update losses. # Update losses.
skipped_iters_key = 'skipped iterations'
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
got_nan_key = 'got nan'
got_nan = False
for key in loss_dict: for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key] total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
got_nan = got_nan or is_nan
total_loss_dict[got_nan_key] = total_loss_dict.get(
got_nan_key, 0) + int(got_nan)
# Logging. # Logging.
timers_to_log = [] timers_to_log = []
...@@ -347,12 +363,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -347,12 +363,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval) elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate: {:.3E} |'.format(learning_rate) log_string += ' learning rate: {:.3E} |'.format(learning_rate)
num_iterations = max(
1, args.log_interval - total_loss_dict[skipped_iters_key])
for key in total_loss_dict: for key in total_loss_dict:
avg = total_loss_dict[key].item() / args.log_interval if key not in [skipped_iters_key, got_nan_key]:
avg = total_loss_dict[key].item() / float(num_iterations)
log_string += ' {}: {:.6E} |'.format(key, avg) log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = 0.0 total_loss_dict[key] = 0.0
if args.fp16: if args.fp16:
log_string += ' loss scale: {:.1f} |'.format(loss_scale) log_string += ' loss scale: {:.1f} |'.format(loss_scale)
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[got_nan_key])
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[got_nan_key] = 0
print_rank_0(log_string) print_rank_0(log_string)
if report_memory_flag: if report_memory_flag:
report_memory('after {} iterations'.format(iteration)) report_memory('after {} iterations'.format(iteration))
...@@ -376,7 +401,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -376,7 +401,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Iterations. # Iterations.
iteration = args.iteration iteration = args.iteration
skipped_iters = 0
timers('interval time').start() timers('interval time').start()
report_memory_flag = True report_memory_flag = True
...@@ -386,7 +410,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -386,7 +410,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
model, model,
optimizer, optimizer,
lr_scheduler) lr_scheduler)
skipped_iters += skipped_iter
iteration += 1 iteration += 1
# Logging. # Logging.
...@@ -396,7 +419,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -396,7 +419,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = training_log(loss_dict, total_loss_dict, report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['lr'],
iteration, loss_scale, iteration, loss_scale,
report_memory_flag) report_memory_flag, skipped_iter)
# Autoresume # Autoresume
if args.adlr_autoresume and \ if args.adlr_autoresume and \
...@@ -425,7 +448,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -425,7 +448,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
'iteration {}'.format(rank, time_str, iteration)) 'iteration {}'.format(rank, time_str, iteration))
sys.exit() sys.exit()
return iteration, skipped_iters return iteration
def evaluate(forward_step_func, data_iterator, model, verbose=False): def evaluate(forward_step_func, data_iterator, model, verbose=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment