Commit cf2f4d9d authored by Jared Casper

Merge branch 'code_reuse' into 'master'

refactored for code reuse

See merge request ADLR/megatron-lm!11
parents beb3e0d3 cbd8c054
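
The main reuse introduced by this branch is a shared `reduce_losses` helper in `megatron.utils` that replaces the hand-written all-reduce of logged losses in the BERT and GPT-2 pretraining scripts, plus `initialize_megatron` and `training_log` functions split out of `run` and `train`. As a quick orientation before the diff, here is a minimal, self-contained sketch of the helper and of how the refactored forward-step code calls it; the `is_initialized()` guard, the `__main__` demo, and the example loss values are additions so the snippet runs standalone and are not part of the commit.

```python
import torch
import torch.distributed


def reduce_losses(losses):
    """Average a list of scalar losses across data-parallel ranks."""
    reduced_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    # Guard added only so this sketch runs outside a distributed job;
    # the helper in the diff assumes torch.distributed is initialized.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.all_reduce(reduced_losses)
        reduced_losses = reduced_losses / torch.distributed.get_world_size()
    return reduced_losses


if __name__ == '__main__':
    # Example loss values, mirroring the refactored BERT forward_step.
    lm_loss, nsp_loss = torch.tensor(2.3), torch.tensor(0.7)
    reduced = reduce_losses([lm_loss, nsp_loss])
    print({'lm loss': reduced[0], 'nsp loss': reduced[1]})
```
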
@@ -71,28 +71,17 @@ def run(top_level_message, train_val_test_data_provider,
     function add `batch generator` to the timers class.
     """
+    # Timer.
+    timers = Timers()
     # Arguments.
     args = get_args()
-    # Timer.
-    timers = Timers()
     # Tensorboard writer
     writer = get_tensorboard_writer(args)
-    # Pytorch distributed.
-    initialize_distributed(args)
-    if torch.distributed.get_rank() == 0:
-        print(top_level_message, flush=True)
-        print_args(args, writer)
-    # Autoresume.
-    torch.distributed.barrier()
-    if args.adlr_autoresume:
-        enable_adlr_autoresume(args)
-    # Random seeds for reproducability.
-    set_random_seed(args.seed)
+    # Initalize.
+    initialize_megatron(top_level_message, args, writer)
     # Data stuff.
     train_data, val_data, test_data = train_val_test_data_provider(args)
@@ -135,6 +124,24 @@ def run(top_level_message, train_val_test_data_provider,
                                    args, None, 0, timers, True)
+def initialize_megatron(message, args, writer):
+    """"Initialize distributed, random seed, and autoresume."""
+    # Pytorch distributed.
+    initialize_distributed(args)
+    if torch.distributed.get_rank() == 0:
+        print(message, flush=True)
+        print_args(args, writer)
+    # Autoresume.
+    torch.distributed.barrier()
+    if args.adlr_autoresume:
+        enable_adlr_autoresume(args)
+    # Random seeds for reproducability.
+    set_random_seed(args.seed)
 def get_model(model_provider_func, args):
     """Build the model."""
@@ -301,6 +308,62 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
     return loss_reduced, skipped_iter
+def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
+                 loss_scale, report_memory_flag, writer, args, timers):
+    # Update losses.
+    for key in loss_dict:
+        total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
+    # Logging.
+    timers_to_log = []
+    def add_to_logging(name):
+        if name in timers.timers:
+            timers_to_log.append(name)
+    add_to_logging('forward')
+    add_to_logging('backward')
+    add_to_logging('allreduce')
+    add_to_logging('optimizer')
+    add_to_logging('batch generator')
+    # Tensorboard values.
+    if writer and torch.distributed.get_rank() == 0:
+        writer.add_scalar('learning_rate', learning_rate, iteration)
+        for key in loss_dict:
+            writer.add_scalar(key, loss_dict[key], iteration)
+        if args.fp16:
+            writer.add_scalar('loss_scale', loss_scale, iteration)
+        normalizer = iteration % args.log_interval
+        if normalizer == 0:
+            normalizer = args.log_interval
+        timers.write(timers_to_log, writer, iteration,
+                     normalizer=normalizer)
+    if iteration % args.log_interval == 0:
+        elapsed_time = timers('interval time').elapsed()
+        if writer and torch.distributed.get_rank() == 0:
+            writer.add_scalar('iteration_time',
+                              elapsed_time / args.log_interval, iteration)
+        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
+                                                       args.train_iters)
+        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
+            elapsed_time * 1000.0 / args.log_interval)
+        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
+        for key in total_loss_dict:
+            avg = total_loss_dict[key].item() / args.log_interval
+            log_string += ' {}: {:.6E} |'.format(key, avg)
+            total_loss_dict[key] = 0.0
+        if args.fp16:
+            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+        print_rank_0(log_string)
+        if report_memory_flag:
+            report_memory('after {} iterations'.format(iteration))
+            report_memory_flag = False
+        timers.log(timers_to_log, normalizer=args.log_interval)
+    return report_memory_flag
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, val_data_iterator, timers, args, writer):
     """Train the model function."""
@@ -328,54 +391,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         skipped_iters += skipped_iter
         iteration += 1
-        # Update losses.
-        for key in loss_dict:
-            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
         # Logging.
-        if args.DDP_impl == 'torch':
-            timers_to_log = ['forward', 'backward', 'optimizer',
-                             'batch generator']
-        else:
-            timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
-                             'batch generator']
-        learning_rate = optimizer.param_groups[0]['lr']
-        if writer and torch.distributed.get_rank() == 0:
-            writer.add_scalar('learning_rate', learning_rate, iteration)
-            for key in total_loss_dict:
-                writer.add_scalar(key, total_loss_dict[key], iteration)
-            if args.fp16:
-                writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
-            normalizer = iteration % args.log_interval
-            if normalizer == 0:
-                normalizer = args.log_interval
-            timers.write(timers_to_log, writer, iteration,
-                         normalizer=normalizer)
-        if iteration % args.log_interval == 0:
-            elapsed_time = timers('interval time').elapsed()
-            if writer and torch.distributed.get_rank() == 0:
-                writer.add_scalar('iteration_time',
-                                  elapsed_time / args.log_interval, iteration)
-            log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
-                                                           args.train_iters)
-            log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
-                elapsed_time * 1000.0 / args.log_interval)
-            log_string += ' learning rate: {:.3E} |'.format(learning_rate)
-            for key in total_loss_dict:
-                avg = total_loss_dict[key].item() / args.log_interval
-                log_string += ' {}: {:.6E} |'.format(key, avg)
-                total_loss_dict[key] = 0.0
-            if args.fp16:
-                log_string += ' loss scale: {:.1f} |'.format(
-                    optimizer.loss_scale)
-            print_rank_0(log_string)
-            if report_memory_flag:
-                report_memory('after {} iterations'.format(iteration))
-                report_memory_flag = False
-            timers.log(timers_to_log, normalizer=args.log_interval)
+        report_memory_flag = training_log(loss_dict, total_loss_dict,
+                                          optimizer.param_groups[0]['lr'],
+                                          iteration, optimizer.loss_scale,
+                                          report_memory_flag, writer, args,
+                                          timers)
         # Autoresume
         if (iteration % args.adlr_autoresume_interval == 0) and \
......
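
For readers skimming the hunks above: the new `training_log` accumulates each iteration's losses into `total_loss_dict`, prints the running average every `args.log_interval` iterations, and then resets the accumulators. Below is a toy, standalone illustration of that accumulate/average/reset pattern; the interval length, the fake per-step losses, and the print format are made up for the example and are not taken from the commit.

```python
import torch

LOG_INTERVAL = 3  # stand-in for args.log_interval

total_loss_dict = {}
for iteration in range(1, 7):
    # Fake per-iteration loss; in training this comes from train_step().
    loss_dict = {'lm loss': torch.tensor(1.0 / iteration)}

    # Accumulate, exactly as training_log does on every call.
    for key in loss_dict:
        total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]

    # Every LOG_INTERVAL iterations, report the average and reset.
    if iteration % LOG_INTERVAL == 0:
        parts = []
        for key in total_loss_dict:
            avg = total_loss_dict[key].item() / LOG_INTERVAL
            parts.append('{}: {:.6E}'.format(key, avg))
            total_loss_dict[key] = 0.0
        print(' iteration {:8d} | {} |'.format(iteration, ' | '.join(parts)))
```
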
@@ -31,9 +31,19 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
+def reduce_losses(losses):
+    reduced_losses = torch.cat(
+        [loss.clone().detach().view(1) for loss in losses])
+    torch.distributed.all_reduce(reduced_losses)
+    reduced_losses = reduced_losses / torch.distributed.get_world_size()
+    return reduced_losses
 def get_tensorboard_writer(args):
     writer = None
-    if args.tensorboard_dir and args.rank == 0:
+    if hasattr(args, 'tensorboard_dir') and \
+       args.tensorboard_dir and args.rank == 0:
         try:
             from torch.utils.tensorboard import SummaryWriter
             writer = SummaryWriter(log_dir=args.tensorboard_dir)
......
@@ -22,6 +22,7 @@ from configure_data import configure_data
 from megatron import mpu
 from megatron.model import BertModel
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -99,14 +100,9 @@ def forward_step(data_iterator, model, args, timers):
     loss = lm_loss + nsp_loss
-    reduced_losses = torch.cat((lm_loss.clone().detach().view(1),
-                                nsp_loss.clone().detach().view(1)))
-    torch.distributed.all_reduce(reduced_losses)
-    reduced_losses = reduced_losses / torch.distributed.get_world_size()
-    lm_loss_reduced = reduced_losses[0]
-    nsp_loss_reduced = reduced_losses[1]
+    reduced_losses = reduce_losses([lm_loss, nsp_loss])
-    return loss, {'lm loss': lm_loss_reduced, 'nsp loss': nsp_loss_reduced}
+    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}
 def get_train_val_test_data(args):
......
@@ -22,6 +22,7 @@ from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -155,11 +156,9 @@ def forward_step(data_iterator, model, args, timers):
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
     # Reduce loss for logging.
-    reduced_loss = loss.clone().detach().view(1)
-    torch.distributed.all_reduce(reduced_loss)
-    reduced_loss = reduced_loss / torch.distributed.get_world_size()
+    reduced_loss = reduce_losses([loss])
-    return loss, {'lm loss': reduced_loss}
+    return loss, {'lm loss': reduced_loss[0]}
 def get_train_val_test_data(args):
......