Commit cf2f4d9d authored by Jared Casper

Merge branch 'code_reuse' into 'master'

refactored for code reuse

See merge request ADLR/megatron-lm!11
parents beb3e0d3 cbd8c054
megatron/training.py

@@ -71,28 +71,17 @@ def run(top_level_message, train_val_test_data_provider,
     function add `batch generator` to the timers class.
     """

-    # Timer.
-    timers = Timers()
-
     # Arguments.
     args = get_args()

+    # Timer.
+    timers = Timers()
+
     # Tensorboard writer
     writer = get_tensorboard_writer(args)

-    # Pytorch distributed.
-    initialize_distributed(args)
-    if torch.distributed.get_rank() == 0:
-        print(top_level_message, flush=True)
-        print_args(args, writer)
-
-    # Autoresume.
-    torch.distributed.barrier()
-    if args.adlr_autoresume:
-        enable_adlr_autoresume(args)
-
-    # Random seeds for reproducability.
-    set_random_seed(args.seed)
+    # Initalize.
+    initialize_megatron(top_level_message, args, writer)

     # Data stuff.
     train_data, val_data, test_data = train_val_test_data_provider(args)
@@ -135,6 +124,24 @@ def run(top_level_message, train_val_test_data_provider,
                                   args, None, 0, timers, True)


+def initialize_megatron(message, args, writer):
+    """"Initialize distributed, random seed, and autoresume."""
+
+    # Pytorch distributed.
+    initialize_distributed(args)
+    if torch.distributed.get_rank() == 0:
+        print(message, flush=True)
+        print_args(args, writer)
+
+    # Autoresume.
+    torch.distributed.barrier()
+    if args.adlr_autoresume:
+        enable_adlr_autoresume(args)
+
+    # Random seeds for reproducability.
+    set_random_seed(args.seed)
+
+
 def get_model(model_provider_func, args):
     """Build the model."""
@@ -301,6 +308,62 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
     return loss_reduced, skipped_iter


+def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
+                 loss_scale, report_memory_flag, writer, args, timers):
+
+    # Update losses.
+    for key in loss_dict:
+        total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
+
+    # Logging.
+    timers_to_log = []
+
+    def add_to_logging(name):
+        if name in timers.timers:
+            timers_to_log.append(name)
+    add_to_logging('forward')
+    add_to_logging('backward')
+    add_to_logging('allreduce')
+    add_to_logging('optimizer')
+    add_to_logging('batch generator')
+
+    # Tensorboard values.
+    if writer and torch.distributed.get_rank() == 0:
+        writer.add_scalar('learning_rate', learning_rate, iteration)
+        for key in loss_dict:
+            writer.add_scalar(key, loss_dict[key], iteration)
+        if args.fp16:
+            writer.add_scalar('loss_scale', loss_scale, iteration)
+        normalizer = iteration % args.log_interval
+        if normalizer == 0:
+            normalizer = args.log_interval
+        timers.write(timers_to_log, writer, iteration,
+                     normalizer=normalizer)
+
+    if iteration % args.log_interval == 0:
+        elapsed_time = timers('interval time').elapsed()
+        if writer and torch.distributed.get_rank() == 0:
+            writer.add_scalar('iteration_time',
+                              elapsed_time / args.log_interval, iteration)
+        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
+                                                       args.train_iters)
+        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
+            elapsed_time * 1000.0 / args.log_interval)
+        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
+        for key in total_loss_dict:
+            avg = total_loss_dict[key].item() / args.log_interval
+            log_string += ' {}: {:.6E} |'.format(key, avg)
+            total_loss_dict[key] = 0.0
+        if args.fp16:
+            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+        print_rank_0(log_string)
+        if report_memory_flag:
+            report_memory('after {} iterations'.format(iteration))
+            report_memory_flag = False
+        timers.log(timers_to_log, normalizer=args.log_interval)
+
+    return report_memory_flag
+
+
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, val_data_iterator, timers, args, writer):
     """Train the model function."""
@@ -328,54 +391,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         skipped_iters += skipped_iter
         iteration += 1

-        # Update losses.
-        for key in loss_dict:
-            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
-
         # Logging.
-        if args.DDP_impl == 'torch':
-            timers_to_log = ['forward', 'backward', 'optimizer',
-                             'batch generator']
-        else:
-            timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
-                             'batch generator']
-
-        learning_rate = optimizer.param_groups[0]['lr']
-        if writer and torch.distributed.get_rank() == 0:
-            writer.add_scalar('learning_rate', learning_rate, iteration)
-            for key in total_loss_dict:
-                writer.add_scalar(key, total_loss_dict[key], iteration)
-            if args.fp16:
-                writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
-            normalizer = iteration % args.log_interval
-            if normalizer == 0:
-                normalizer = args.log_interval
-            timers.write(timers_to_log, writer, iteration,
-                         normalizer=normalizer)
-
-        if iteration % args.log_interval == 0:
-            elapsed_time = timers('interval time').elapsed()
-            if writer and torch.distributed.get_rank() == 0:
-                writer.add_scalar('iteration_time',
-                                  elapsed_time / args.log_interval, iteration)
-            log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
-                                                           args.train_iters)
-            log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
-                elapsed_time * 1000.0 / args.log_interval)
-            log_string += ' learning rate: {:.3E} |'.format(learning_rate)
-            for key in total_loss_dict:
-                avg = total_loss_dict[key].item() / args.log_interval
-                log_string += ' {}: {:.6E} |'.format(key, avg)
-                total_loss_dict[key] = 0.0
-            if args.fp16:
-                log_string += ' loss scale: {:.1f} |'.format(
-                    optimizer.loss_scale)
-            print_rank_0(log_string)
-            if report_memory_flag:
-                report_memory('after {} iterations'.format(iteration))
-                report_memory_flag = False
-            timers.log(timers_to_log, normalizer=args.log_interval)
+        report_memory_flag = training_log(loss_dict, total_loss_dict,
+                                          optimizer.param_groups[0]['lr'],
+                                          iteration, optimizer.loss_scale,
+                                          report_memory_flag, writer, args,
+                                          timers)

         # Autoresume
         if (iteration % args.adlr_autoresume_interval == 0) and \
......
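Note (not part of the commit): the new training_log helper decides which timers to report by checking membership in timers.timers, rather than branching on args.DDP_impl as the old train loop did, so a timer such as 'allreduce' is logged only when some code path actually started it. A minimal sketch of that filtering idea, using a hypothetical stand-in for Megatron's Timers class:

    # Hypothetical stand-in: Megatron's Timers keeps its named timers in a dict.
    class SimpleTimers:
        def __init__(self):
            self.timers = {}  # name -> timer object (here just an elapsed-ms float)

    def collect_timer_names(timers, candidates):
        """Keep only the candidate timer names that were actually started."""
        return [name for name in candidates if name in timers.timers]

    timers = SimpleTimers()
    timers.timers['forward'] = 12.3   # e.g. a run that never timed 'allreduce'
    timers.timers['backward'] = 30.1
    print(collect_timer_names(
        timers,
        ['forward', 'backward', 'allreduce', 'optimizer', 'batch generator']))
    # -> ['forward', 'backward']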
megatron/utils.py

@@ -31,9 +31,19 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization


+def reduce_losses(losses):
+    reduced_losses = torch.cat(
+        [loss.clone().detach().view(1) for loss in losses])
+    torch.distributed.all_reduce(reduced_losses)
+    reduced_losses = reduced_losses / torch.distributed.get_world_size()
+    return reduced_losses
+
+
 def get_tensorboard_writer(args):
     writer = None
-    if args.tensorboard_dir and args.rank == 0:
+    if hasattr(args, 'tensorboard_dir') and \
+       args.tensorboard_dir and args.rank == 0:
         try:
             from torch.utils.tensorboard import SummaryWriter
             writer = SummaryWriter(log_dir=args.tensorboard_dir)
......
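Note (not part of the commit): reduce_losses stacks the detached scalar losses into a single 1-D tensor, sums it across all data-parallel ranks with one all_reduce, and divides by the world size, so callers get rank-averaged values for logging without touching the autograd graph. A minimal usage sketch; it assumes the launcher has already called torch.distributed.init_process_group, that megatron.utils is importable, and that the loss values below are dummies standing in for real model outputs:

    import torch
    from megatron.utils import reduce_losses

    # Dummy per-rank losses standing in for real model outputs.
    lm_loss = torch.tensor(2.31)
    nsp_loss = torch.tensor(0.47)

    averaged = reduce_losses([lm_loss, nsp_loss])   # 1-D tensor of rank-averaged values
    log_dict = {'lm loss': averaged[0], 'nsp loss': averaged[1]}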
pretrain_bert.py

@@ -22,6 +22,7 @@ from configure_data import configure_data
 from megatron import mpu
 from megatron.model import BertModel
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -99,14 +100,9 @@ def forward_step(data_iterator, model, args, timers):
     loss = lm_loss + nsp_loss

-    reduced_losses = torch.cat((lm_loss.clone().detach().view(1),
-                                nsp_loss.clone().detach().view(1)))
-    torch.distributed.all_reduce(reduced_losses)
-    reduced_losses = reduced_losses / torch.distributed.get_world_size()
-    lm_loss_reduced = reduced_losses[0]
-    nsp_loss_reduced = reduced_losses[1]
+    reduced_losses = reduce_losses([lm_loss, nsp_loss])

-    return loss, {'lm loss': lm_loss_reduced, 'nsp loss': nsp_loss_reduced}
+    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}


 def get_train_val_test_data(args):
......
pretrain_gpt2.py

@@ -22,6 +22,7 @@ from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -155,11 +156,9 @@ def forward_step(data_iterator, model, args, timers):
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

     # Reduce loss for logging.
-    reduced_loss = loss.clone().detach().view(1)
-    torch.distributed.all_reduce(reduced_loss)
-    reduced_loss = reduced_loss / torch.distributed.get_world_size()
+    reduced_loss = reduce_losses([loss])

-    return loss, {'lm loss': reduced_loss}
+    return loss, {'lm loss': reduced_loss[0]}


 def get_train_val_test_data(args):
......
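Note (not part of the commit): both pretraining scripts now share the same reduction helper and differ only in how many losses they pass; reduce_losses always returns a 1-D tensor, so the single-loss GPT-2 case recovers its scalar by indexing element 0. A hedged sketch of the two call shapes (the log_losses wrapper below is hypothetical):

    from megatron.utils import reduce_losses

    def log_losses(lm_loss, nsp_loss=None):
        """Hypothetical wrapper: return rank-averaged losses keyed for logging."""
        if nsp_loss is None:
            # GPT-2 style: one loss; index [0] to recover the scalar.
            reduced = reduce_losses([lm_loss])
            return {'lm loss': reduced[0]}
        # BERT style: two losses reduced with a single collective call.
        reduced = reduce_losses([lm_loss, nsp_loss])
        return {'lm loss': reduced[0], 'nsp loss': reduced[1]}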