Commit cf2f4d9d authored by Jared Casper

Merge branch 'code_reuse' into 'master'

refactored for code reuse

See merge request ADLR/megatron-lm!11
parents beb3e0d3 cbd8c054
@@ -71,28 +71,17 @@ def run(top_level_message, train_val_test_data_provider,
     function add `batch generator` to the timers class.
     """

-    # Timer.
-    timers = Timers()
-
     # Arguments.
     args = get_args()

+    # Timer.
+    timers = Timers()
+
     # Tensorboard writer
     writer = get_tensorboard_writer(args)

-    # Pytorch distributed.
-    initialize_distributed(args)
-    if torch.distributed.get_rank() == 0:
-        print(top_level_message, flush=True)
-        print_args(args, writer)
-
-    # Autoresume.
-    torch.distributed.barrier()
-    if args.adlr_autoresume:
-        enable_adlr_autoresume(args)
-
-    # Random seeds for reproducibility.
-    set_random_seed(args.seed)
+    # Initialize.
+    initialize_megatron(top_level_message, args, writer)

     # Data stuff.
     train_data, val_data, test_data = train_val_test_data_provider(args)
@@ -135,6 +124,24 @@ def run(top_level_message, train_val_test_data_provider,
                                    args, None, 0, timers, True)


+def initialize_megatron(message, args, writer):
+    """Initialize distributed, random seed, and autoresume."""
+
+    # Pytorch distributed.
+    initialize_distributed(args)
+    if torch.distributed.get_rank() == 0:
+        print(message, flush=True)
+        print_args(args, writer)
+
+    # Autoresume.
+    torch.distributed.barrier()
+    if args.adlr_autoresume:
+        enable_adlr_autoresume(args)
+
+    # Random seeds for reproducibility.
+    set_random_seed(args.seed)
+
+
 def get_model(model_provider_func, args):
     """Build the model."""
@@ -301,53 +308,31 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
     return loss_reduced, skipped_iter


-def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator, timers, args, writer):
-    """Train the model function."""
-
-    # Turn on training mode which enables dropout.
-    model.train()
-
-    # Tracking loss.
-    total_loss_dict = {}
-
-    # Iterations.
-    iteration = args.iteration
-    skipped_iters = 0
-
-    timers('interval time').start()
-    report_memory_flag = True
-    while iteration < args.train_iters:
-
-        loss_dict, skipped_iter = train_step(forward_step_func,
-                                             train_data_iterator,
-                                             model,
-                                             optimizer,
-                                             lr_scheduler,
-                                             args, timers)
-        skipped_iters += skipped_iter
-        iteration += 1
-
-        # Update losses.
-        for key in loss_dict:
-            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
-
-        # Logging.
-        if args.DDP_impl == 'torch':
-            timers_to_log = ['forward', 'backward', 'optimizer',
-                             'batch generator']
-        else:
-            timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
-                             'batch generator']
-
-        learning_rate = optimizer.param_groups[0]['lr']
-
-        if writer and torch.distributed.get_rank() == 0:
-            writer.add_scalar('learning_rate', learning_rate, iteration)
-            for key in total_loss_dict:
-                writer.add_scalar(key, total_loss_dict[key], iteration)
-            if args.fp16:
-                writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
+def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
+                 loss_scale, report_memory_flag, writer, args, timers):
+
+    # Update losses.
+    for key in loss_dict:
+        total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
+
+    # Logging.
+    timers_to_log = []
+
+    def add_to_logging(name):
+        if name in timers.timers:
+            timers_to_log.append(name)
+    add_to_logging('forward')
+    add_to_logging('backward')
+    add_to_logging('allreduce')
+    add_to_logging('optimizer')
+    add_to_logging('batch generator')
+
+    # Tensorboard values.
+    if writer and torch.distributed.get_rank() == 0:
+        writer.add_scalar('learning_rate', learning_rate, iteration)
+        for key in loss_dict:
+            writer.add_scalar(key, loss_dict[key], iteration)
+        if args.fp16:
+            writer.add_scalar('loss_scale', loss_scale, iteration)

     normalizer = iteration % args.log_interval
     if normalizer == 0:
         normalizer = args.log_interval
@@ -369,14 +354,50 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             log_string += ' {}: {:.6E} |'.format(key, avg)
             total_loss_dict[key] = 0.0
         if args.fp16:
-            log_string += ' loss scale: {:.1f} |'.format(
-                optimizer.loss_scale)
+            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
         print_rank_0(log_string)
         if report_memory_flag:
             report_memory('after {} iterations'.format(iteration))
             report_memory_flag = False
         timers.log(timers_to_log, normalizer=args.log_interval)

+    return report_memory_flag
+
+
+def train(forward_step_func, model, optimizer, lr_scheduler,
+          train_data_iterator, val_data_iterator, timers, args, writer):
+    """Train the model function."""
+
+    # Turn on training mode which enables dropout.
+    model.train()
+
+    # Tracking loss.
+    total_loss_dict = {}
+
+    # Iterations.
+    iteration = args.iteration
+    skipped_iters = 0
+
+    timers('interval time').start()
+    report_memory_flag = True
+    while iteration < args.train_iters:
+
+        loss_dict, skipped_iter = train_step(forward_step_func,
+                                             train_data_iterator,
+                                             model,
+                                             optimizer,
+                                             lr_scheduler,
+                                             args, timers)
+        skipped_iters += skipped_iter
+        iteration += 1
+
+        # Logging.
+        report_memory_flag = training_log(loss_dict, total_loss_dict,
+                                          optimizer.param_groups[0]['lr'],
+                                          iteration, optimizer.loss_scale,
+                                          report_memory_flag, writer, args,
+                                          timers)
+
         # Autoresume
         if (iteration % args.adlr_autoresume_interval == 0) and \
            args.adlr_autoresume:
...
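The `add_to_logging` closure above replaces the old branch on `args.DDP_impl`: instead of hard-coding a timer list per DDP implementation, `training_log` reports whichever timers were actually created. A minimal sketch of that filtering pattern, using a hypothetical `FakeTimers` stand-in (only the `.timers` dict membership check is taken from the diff; Megatron's real `Timers` class is not reproduced here):

```python
# Hypothetical stand-in for Megatron's Timers: only the `.timers` dict
# membership check used by training_log is modeled.
class FakeTimers:
    def __init__(self, names):
        self.timers = {name: object() for name in names}


def collect_timers_to_log(timers):
    """Keep only the timers that exist, mirroring add_to_logging in training_log."""
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)

    for name in ('forward', 'backward', 'allreduce', 'optimizer',
                 'batch generator'):
        add_to_logging(name)
    return timers_to_log


# With the torch DDP path there is no separate 'allreduce' timer,
# so it simply drops out of the report instead of needing its own branch.
print(collect_timers_to_log(
    FakeTimers(['forward', 'backward', 'optimizer', 'batch generator'])))
print(collect_timers_to_log(
    FakeTimers(['forward', 'backward', 'allreduce', 'optimizer',
                'batch generator'])))
```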
@@ -31,9 +31,19 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization


+def reduce_losses(losses):
+    reduced_losses = torch.cat(
+        [loss.clone().detach().view(1) for loss in losses])
+    torch.distributed.all_reduce(reduced_losses)
+    reduced_losses = reduced_losses / torch.distributed.get_world_size()
+    return reduced_losses
+
+
 def get_tensorboard_writer(args):
     writer = None
-    if args.tensorboard_dir and args.rank == 0:
+    if hasattr(args, 'tensorboard_dir') and \
+       args.tensorboard_dir and args.rank == 0:
         try:
             from torch.utils.tensorboard import SummaryWriter
             writer = SummaryWriter(log_dir=args.tensorboard_dir)
...
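The new `reduce_losses` helper centralizes the all-reduce-and-average pattern that the BERT and GPT-2 scripts previously duplicated (see the hunks below). As an illustration only, here is a hedged, single-process sketch of its behavior; the gloo backend, the MASTER_ADDR/MASTER_PORT settings, and the standalone copy of the function body are assumptions made so it can run outside Megatron, where the process group is instead set up by `initialize_megatron`:

```python
import os

import torch
import torch.distributed


def reduce_losses(losses):
    # Same body as the helper added to megatron.utils above.
    reduced_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(reduced_losses)
    reduced_losses = reduced_losses / torch.distributed.get_world_size()
    return reduced_losses


if __name__ == '__main__':
    # Single-process initialization so the example runs without a launcher
    # (assumption: the gloo backend is available in this PyTorch build).
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    torch.distributed.init_process_group('gloo', rank=0, world_size=1)

    lm_loss = torch.tensor(2.5)
    nsp_loss = torch.tensor(0.7)
    # With world_size == 1 the all-reduce is an identity, so this prints
    # tensor([2.5000, 0.7000]); across N data-parallel ranks each entry
    # would be that loss averaged over all ranks.
    print(reduce_losses([lm_loss, nsp_loss]))
```

This averaged tensor is exactly what the refactored `forward_step` functions return for logging, e.g. `{'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}`.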
@@ -22,6 +22,7 @@ from configure_data import configure_data
 from megatron import mpu
 from megatron.model import BertModel
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -99,14 +100,9 @@ def forward_step(data_iterator, model, args, timers):
     loss = lm_loss + nsp_loss

-    reduced_losses = torch.cat((lm_loss.clone().detach().view(1),
-                                nsp_loss.clone().detach().view(1)))
-    torch.distributed.all_reduce(reduced_losses)
-    reduced_losses = reduced_losses / torch.distributed.get_world_size()
-    lm_loss_reduced = reduced_losses[0]
-    nsp_loss_reduced = reduced_losses[1]
+    reduced_losses = reduce_losses([lm_loss, nsp_loss])

-    return loss, {'lm loss': lm_loss_reduced, 'nsp loss': nsp_loss_reduced}
+    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}


 def get_train_val_test_data(args):
...
@@ -22,6 +22,7 @@ from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -155,11 +156,9 @@ def forward_step(data_iterator, model, args, timers):
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

     # Reduce loss for logging.
-    reduced_loss = loss.clone().detach().view(1)
-    torch.distributed.all_reduce(reduced_loss)
-    reduced_loss = reduced_loss / torch.distributed.get_world_size()
+    reduced_loss = reduce_losses([loss])

-    return loss, {'lm loss': reduced_loss}
+    return loss, {'lm loss': reduced_loss[0]}


 def get_train_val_test_data(args):
...