Commit b81cad66 authored by mohammad, committed by Deepak Narayanan

Fix TensorBoard writes

parent 5a304ede
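The diff below moves TensorBoard setup and writes from rank 0 to the last rank (rank == world_size - 1) and switches the logged tag names to dashed forms. A minimal sketch of the equivalent rank check follows; the helper name matches the `is_last_rank` import added in training.py, but its actual implementation is not part of this diff and may differ.

```python
# Minimal sketch of the last-rank check this commit switches to
# (rank == world_size - 1). megatron.is_last_rank is imported further
# down; its real implementation is not shown here and may differ.
import torch


def is_last_rank():
    return torch.distributed.get_rank() == (
        torch.distributed.get_world_size() - 1)
```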
@@ -131,7 +131,7 @@ def _set_tensorboard_writer(args):
'tensorboard writer')
if hasattr(args, 'tensorboard_dir') and \
-args.tensorboard_dir and args.rank == 0:
+args.tensorboard_dir and args.rank == (args.world_size -1):
try:
from torch.utils.tensorboard import SummaryWriter
print('> setting tensorboard ...')
@@ -242,7 +242,7 @@ class Timers:
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
-writer.add_scalar(name + '_time', value, iteration)
+writer.add_scalar(name + '-time', value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
@@ -253,7 +253,8 @@ class Timers:
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
-if torch.distributed.get_rank() == 0:
+if torch.distributed.get_rank() == (
+torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
@@ -31,6 +31,7 @@ from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
+from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
@@ -675,12 +676,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
timers = get_timers()
writer = get_tensorboard_writer()
-# Update losses.
+# Advanced, skipped, and Nan iterations.
+advanced_iters_key = 'advanced iterations'
skipped_iters_key = 'skipped iterations'
+nan_iters_key = 'nan iterations'
+# Advanced iterations.
+if not skipped_iter:
+total_loss_dict[advanced_iters_key] = total_loss_dict.get(
+advanced_iters_key, 0) + 1
+else:
+if advanced_iters_key not in total_loss_dict:
+total_loss_dict[advanced_iters_key] = 0
+# Skipped iterations.
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
-got_nan_key = 'got nan'
+# Update losses and set nan iterations
got_nan = False
for key in loss_dict:
if not skipped_iter:
@@ -692,9 +702,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
value == -float('inf') or \
value != value
got_nan = got_nan or is_nan
-total_loss_dict[got_nan_key] = total_loss_dict.get(
-got_nan_key, 0) + int(got_nan)
+total_loss_dict[nan_iters_key] = total_loss_dict.get(
+nan_iters_key, 0) + int(got_nan)
# Logging.
timers_to_log = []
@@ -715,51 +724,53 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-clip-grad')
add_to_logging('optimizer')
-add_to_logging('batch generator')
+add_to_logging('batch-generator')
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
+total_iterations = total_loss_dict[advanced_iters_key] + \
+total_loss_dict[skipped_iters_key]
# Tensorboard values.
-if writer and torch.distributed.get_rank() == 0:
-writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
-writer.add_scalar('learning_rate-samples', learning_rate,
+if writer and is_last_rank():
+writer.add_scalar('learning-rate', learning_rate, iteration)
+writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
-writer.add_scalar('batch_size-iterations', batch_size, iteration)
-writer.add_scalar('batch_size-samples', batch_size,
+writer.add_scalar('batch-size', batch_size, iteration)
+writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
for key in loss_dict:
-writer.add_scalar(key + '-iterations', loss_dict[key], iteration)
-writer.add_scalar(key + '-samples', loss_dict[key],
+writer.add_scalar(key , loss_dict[key], iteration)
+writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
if args.fp16:
-writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
-writer.add_scalar('loss_scale-samples', loss_scale,
+writer.add_scalar('loss-scale', loss_scale, iteration)
+writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
-normalizer = iteration % args.log_interval
-if normalizer == 0:
-normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
-normalizer=normalizer)
+normalizer=total_iterations)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval time').elapsed()
+elapsed_time_per_iteration = elapsed_time / total_iterations
if writer and torch.distributed.get_rank() == 0:
-writer.add_scalar('iteration_time',
-elapsed_time / args.log_interval, iteration)
+writer.add_scalar('iteration-time',
+elapsed_time_per_iteration, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples: {:12d} |'.format(
args.consumed_train_samples)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
-elapsed_time * 1000.0 / args.log_interval)
+elapsed_time_per_iteration * 1000.0)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
-log_string += ' global batch size: {:6d} |'.format(batch_size)
-num_iterations = max(
-1, args.log_interval - total_loss_dict[skipped_iters_key])
+log_string += ' global batch size: {:5d} |'.format(batch_size)
for key in total_loss_dict:
-if key not in [skipped_iters_key, got_nan_key]:
-avg = total_loss_dict[key].item() / float(num_iterations)
+if key not in [advanced_iters_key, skipped_iters_key,
+nan_iters_key]:
+avg = total_loss_dict[key].item() / \
+float(max(1, total_loss_dict[advanced_iters_key]))
if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
@@ -768,9 +779,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
-total_loss_dict[got_nan_key])
+total_loss_dict[nan_iters_key])
+total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
-total_loss_dict[got_nan_key] = 0
+total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
# Report memory after optimizer state has been initialized.
......
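For reference, a condensed sketch of the bookkeeping that the training_log changes above introduce. The helpers below are hypothetical simplifications (the real function also builds the log string and writes TensorBoard scalars); only the dictionary keys mirror the diff. Losses are averaged over advanced (non-skipped) iterations, while elapsed time is divided by advanced plus skipped iterations.

```python
# Hypothetical, condensed sketch of the counters added above; key names
# mirror the diff, everything else is simplified for illustration.
ADVANCED_KEY = 'advanced iterations'
SKIPPED_KEY = 'skipped iterations'
NAN_KEY = 'nan iterations'


def update_counts(total_loss_dict, skipped_iter, got_nan):
    # Advanced iterations count only steps that were not skipped.
    if not skipped_iter:
        total_loss_dict[ADVANCED_KEY] = total_loss_dict.get(ADVANCED_KEY, 0) + 1
    else:
        total_loss_dict.setdefault(ADVANCED_KEY, 0)
    # Skipped and NaN iterations accumulate independently.
    total_loss_dict[SKIPPED_KEY] = total_loss_dict.get(SKIPPED_KEY, 0) + skipped_iter
    total_loss_dict[NAN_KEY] = total_loss_dict.get(NAN_KEY, 0) + int(got_nan)


def time_per_iteration(elapsed_time, total_loss_dict):
    # 'total_iterations' in the diff: advanced plus skipped steps.
    return elapsed_time / (total_loss_dict[ADVANCED_KEY] +
                           total_loss_dict[SKIPPED_KEY])
```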
@@ -87,10 +87,10 @@ def forward_step(data_iterator, model, input_tensor):
timers = get_timers()
# Get the batch.
-timers('batch generator').start()
+timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator)
-timers('batch generator').stop()
+timers('batch-generator').stop()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
......
@@ -87,10 +87,10 @@ def forward_step(data_iterator, model, input_tensor):
timers = get_timers()
# Get the batch.
-timers('batch generator').start()
+timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
-timers('batch generator').stop()
+timers('batch-generator').stop()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
......
@@ -79,10 +79,10 @@ def forward_step(data_iterator, model, input_tensor):
timers = get_timers()
# Get the batch.
-timers('batch generator').start()
+timers('batch-generator').start()
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
-timers('batch generator').stop()
+timers('batch-generator').stop()
# Forward model.
......
@@ -50,13 +50,13 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
timers = get_timers()
# Get the batch.
-timers('batch generator').start()
+timers('batch-generator').start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_)
-timers('batch generator').stop()
+timers('batch-generator').stop()
# Forward model.
if mpu.is_pipeline_first_stage():
......
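Taken together with the Timers.write change above, the renamed timer ends up in TensorBoard under a dashed tag. A standalone illustration with made-up values follows; the log directory, iteration number, and timing are hypothetical.

```python
# Standalone illustration with hypothetical values: Timers.write() now
# appends '-time' to each timer name, so the renamed 'batch-generator'
# timer is logged under the tag 'batch-generator-time'.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='/tmp/tb-example')  # hypothetical directory
iteration = 100                                    # hypothetical step
batch_generator_seconds = 0.0123                   # hypothetical timing
writer.add_scalar('batch-generator' + '-time', batch_generator_seconds, iteration)
writer.close()
```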