Commit b81cad66 authored by mohammad, committed by Deepak Narayanan

Fix TensorBoard writes

parent 5a304ede
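Under pipeline model parallelism the loss is computed on the last pipeline stage, so this commit moves TensorBoard writes and console logging from rank 0 to the last rank, renames scalar tags to a consistent hyphenated style, and normalizes timer and loss averages by the number of iterations actually taken. The diff imports is_last_rank from megatron without showing its definition; judging from the inline check added to Timers.log below, it presumably reduces to something like this sketch (an assumption, not code from this commit):

    # Hypothetical sketch of the helper imported as
    # `from megatron import is_last_rank`; its definition is not part of
    # this diff. It mirrors the inline check used in Timers.log.
    import torch


    def is_last_rank():
        # True only on the highest-numbered rank, i.e. the last pipeline stage.
        return torch.distributed.get_rank() == (
            torch.distributed.get_world_size() - 1)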
@@ -131,7 +131,7 @@ def _set_tensorboard_writer(args):
                                        'tensorboard writer')
     if hasattr(args, 'tensorboard_dir') and \
-       args.tensorboard_dir and args.rank == 0:
+       args.tensorboard_dir and args.rank == (args.world_size - 1):
         try:
             from torch.utils.tensorboard import SummaryWriter
             print('> setting tensorboard ...')
@@ -242,7 +242,7 @@ class Timers:
         assert normalizer > 0.0
         for name in names:
             value = self.timers[name].elapsed(reset=reset) / normalizer
-            writer.add_scalar(name + '_time', value, iteration)
+            writer.add_scalar(name + '-time', value, iteration)

     def log(self, names, normalizer=1.0, reset=True):
         """Log a group of timers."""
@@ -253,7 +253,8 @@ class Timers:
                 reset=reset) * 1000.0 / normalizer
             string += ' | {}: {:.2f}'.format(name, elapsed_time)
         if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == 0:
+            if torch.distributed.get_rank() == (
+                    torch.distributed.get_world_size() - 1):
                 print(string, flush=True)
         else:
             print(string, flush=True)
@@ -31,6 +31,7 @@ from megatron import get_timers
 from megatron import get_tensorboard_writer
 from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
+from megatron import is_last_rank
 from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0
@@ -675,12 +676,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     timers = get_timers()
     writer = get_tensorboard_writer()

-    # Update losses.
+    # Advanced, skipped, and Nan iterations.
+    advanced_iters_key = 'advanced iterations'
     skipped_iters_key = 'skipped iterations'
+    nan_iters_key = 'nan iterations'
+    # Advanced iterations.
+    if not skipped_iter:
+        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
+            advanced_iters_key, 0) + 1
+    else:
+        if advanced_iters_key not in total_loss_dict:
+            total_loss_dict[advanced_iters_key] = 0
+    # Skipped iterations.
     total_loss_dict[skipped_iters_key] = total_loss_dict.get(
         skipped_iters_key, 0) + skipped_iter
-    got_nan_key = 'got nan'
-
+    # Update losses and set nan iterations
     got_nan = False
     for key in loss_dict:
         if not skipped_iter:
@@ -692,9 +702,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                      value == -float('inf') or \
                      value != value
             got_nan = got_nan or is_nan
-
-    total_loss_dict[got_nan_key] = total_loss_dict.get(
-        got_nan_key, 0) + int(got_nan)
+    total_loss_dict[nan_iters_key] = total_loss_dict.get(
+        nan_iters_key, 0) + int(got_nan)

     # Logging.
     timers_to_log = []
@@ -715,51 +724,53 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     add_to_logging('backward-embedding-all-reduce')
     add_to_logging('backward-clip-grad')
     add_to_logging('optimizer')
-    add_to_logging('batch generator')
+    add_to_logging('batch-generator')

+    # Calculate batch size.
     batch_size = args.micro_batch_size * args.data_parallel_size * \
         get_num_microbatches()
+
+    total_iterations = total_loss_dict[advanced_iters_key] + \
+                       total_loss_dict[skipped_iters_key]
+
     # Tensorboard values.
-    if writer and torch.distributed.get_rank() == 0:
-        writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
-        writer.add_scalar('learning_rate-samples', learning_rate,
+    if writer and is_last_rank():
+        writer.add_scalar('learning-rate', learning_rate, iteration)
+        writer.add_scalar('learning-rate vs samples', learning_rate,
                           args.consumed_train_samples)
-        writer.add_scalar('batch_size-iterations', batch_size, iteration)
-        writer.add_scalar('batch_size-samples', batch_size,
+        writer.add_scalar('batch-size', batch_size, iteration)
+        writer.add_scalar('batch-size vs samples', batch_size,
                           args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key + '-iterations', loss_dict[key], iteration)
-            writer.add_scalar(key + '-samples', loss_dict[key],
+            writer.add_scalar(key, loss_dict[key], iteration)
+            writer.add_scalar(key + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
         if args.fp16:
-            writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
-            writer.add_scalar('loss_scale-samples', loss_scale,
+            writer.add_scalar('loss-scale', loss_scale, iteration)
+            writer.add_scalar('loss-scale vs samples', loss_scale,
                               args.consumed_train_samples)
-        normalizer = iteration % args.log_interval
-        if normalizer == 0:
-            normalizer = args.log_interval
         timers.write(timers_to_log, writer, iteration,
-                     normalizer=normalizer)
+                     normalizer=total_iterations)

     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval time').elapsed()
+        elapsed_time_per_iteration = elapsed_time / total_iterations
         if writer and torch.distributed.get_rank() == 0:
-            writer.add_scalar('iteration_time',
-                              elapsed_time / args.log_interval, iteration)
+            writer.add_scalar('iteration-time',
+                              elapsed_time_per_iteration, iteration)
         log_string = ' iteration {:8d}/{:8d} |'.format(
             iteration, args.train_iters)
         log_string += ' consumed samples: {:12d} |'.format(
             args.consumed_train_samples)
         log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
-            elapsed_time * 1000.0 / args.log_interval)
+            elapsed_time_per_iteration * 1000.0)
         log_string += ' learning rate: {:.3E} |'.format(learning_rate)
-        log_string += ' global batch size: {:6d} |'.format(batch_size)
-        num_iterations = max(
-            1, args.log_interval - total_loss_dict[skipped_iters_key])
+        log_string += ' global batch size: {:5d} |'.format(batch_size)
         for key in total_loss_dict:
-            if key not in [skipped_iters_key, got_nan_key]:
-                avg = total_loss_dict[key].item() / float(num_iterations)
+            if key not in [advanced_iters_key, skipped_iters_key,
+                           nan_iters_key]:
+                avg = total_loss_dict[key].item() / \
+                    float(max(1, total_loss_dict[advanced_iters_key]))
                 if avg > 0.0:
                     log_string += ' {}: {:.6E} |'.format(key, avg)
                 total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
@@ -768,9 +779,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         log_string += ' number of skipped iterations: {:3d} |'.format(
             total_loss_dict[skipped_iters_key])
         log_string += ' number of nan iterations: {:3d} |'.format(
-            total_loss_dict[got_nan_key])
+            total_loss_dict[nan_iters_key])
+        total_loss_dict[advanced_iters_key] = 0
         total_loss_dict[skipped_iters_key] = 0
-        total_loss_dict[got_nan_key] = 0
+        total_loss_dict[nan_iters_key] = 0
         print_rank_last(log_string)
         if report_memory_flag and learning_rate > 0.:
             # Report memory after optimizer state has been initialized.
...
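The training_log changes above replace the old per-log_interval normalization with explicit counters: loss averages are divided by the number of advanced (non-skipped) iterations, while timers and the per-iteration elapsed time are normalized by advanced plus skipped iterations. A minimal standalone sketch of that bookkeeping, with key names taken from the diff (illustrative only, not the full function):

    # Sketch of the counter bookkeeping introduced in training_log above.
    # `skipped_iter` is 0 or 1 and `got_nan` is a bool, as in the diff.
    def update_iteration_counters(total_loss_dict, skipped_iter, got_nan):
        advanced_iters_key = 'advanced iterations'
        skipped_iters_key = 'skipped iterations'
        nan_iters_key = 'nan iterations'
        # Count iterations whose update was actually applied.
        if not skipped_iter:
            total_loss_dict[advanced_iters_key] = total_loss_dict.get(
                advanced_iters_key, 0) + 1
        else:
            total_loss_dict.setdefault(advanced_iters_key, 0)
        # Count skipped and NaN iterations.
        total_loss_dict[skipped_iters_key] = total_loss_dict.get(
            skipped_iters_key, 0) + skipped_iter
        total_loss_dict[nan_iters_key] = total_loss_dict.get(
            nan_iters_key, 0) + int(got_nan)
        # Timers are normalized over all iterations seen in the window;
        # loss averages only over the advanced ones, so skipped steps do
        # not dilute the reported loss.
        total_iterations = total_loss_dict[advanced_iters_key] + \
            total_loss_dict[skipped_iters_key]
        loss_normalizer = max(1, total_loss_dict[advanced_iters_key])
        return total_iterations, loss_normalizer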
@@ -87,10 +87,10 @@ def forward_step(data_iterator, model, input_tensor):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch-generator').start()
     tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
         = get_batch(data_iterator)
-    timers('batch generator').stop()
+    timers('batch-generator').stop()

     # Forward pass through the model.
     if mpu.is_pipeline_first_stage():
...
@@ -87,10 +87,10 @@ def forward_step(data_iterator, model, input_tensor):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch-generator').start()
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator)
-    timers('batch generator').stop()
+    timers('batch-generator').stop()

     # Forward pass through the model.
     if mpu.is_pipeline_first_stage():
...
@@ -79,10 +79,10 @@ def forward_step(data_iterator, model, input_tensor):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch-generator').start()
     query_tokens, query_pad_mask, \
         block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
-    timers('batch generator').stop()
+    timers('batch-generator').stop()

     # Forward model.
...
@@ -50,13 +50,13 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch-generator').start()
     try:
         batch_ = next(batch)
     except BaseException:
         batch_ = batch
     tokens, types, labels, attention_mask = process_batch(batch_)
-    timers('batch generator').stop()
+    timers('batch-generator').stop()

     # Forward model.
     if mpu.is_pipeline_first_stage():
...
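The remaining hunks only rename the 'batch generator' timer to 'batch-generator' in the various forward_step functions, matching the hyphenated style of the other timer keys and the add_to_logging('batch-generator') call above. Combined with the '-time' suffix that Timers.write now appends, the resulting TensorBoard tag is, for illustration:

    # Illustrative only: hyphenated timer key plus the '-time' suffix
    # appended by Timers.write.
    name = 'batch-generator'
    tag = name + '-time'
    assert tag == 'batch-generator-time'  # previously 'batch generator_time'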