Commit 72c5f666 authored by Jared Casper

Use a DDP command-line argument instead of a source flag in pretrain_bert.py.

Note that there is currently an issue with BERT when using torch DDP.
parent a54978bb
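For reference, a minimal sketch of how the new switch might be declared with argparse; the corresponding change to the argument parser is not part of this diff, so the option name, default, and choices below are assumptions inferred from the args.DDP_impl checks in the hunks that follow.

# Sketch only: hypothetical registration of the DDP implementation switch.
import argparse

parser = argparse.ArgumentParser(description='BERT pretraining (sketch)')
parser.add_argument('--DDP-impl', default='local',
                    choices=['local', 'torch'],
                    help="Which DistributedDataParallel wrapper to use: the "
                         "repository's local implementation or torch's.")

# argparse maps the dashed option to the attribute args.DDP_impl.
args = parser.parse_args(['--DDP-impl', 'torch'])
assert args.DDP_impl == 'torch'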
@@ -15,9 +15,6 @@
 """Pretrain BERT"""

-# Flag to use Pytorch ddp which uses overlapping communication and computation.
-USE_TORCH_DDP = False
-
 from datetime import datetime
 import os
 import random
@@ -33,10 +30,7 @@ from learning_rates import AnnealingLR
 from model import BertModel
 from model import get_params_for_weight_decay_optimization
 from model import gpt2_get_params_for_weight_decay_optimization
-if USE_TORCH_DDP:
-    from torch.nn.parallel.distributed import DistributedDataParallel as DDP
-else:
-    from model import DistributedDataParallel as DDP
+from model import DistributedDataParallel as LocalDDP
 import mpu
 from apex.optimizers import FusedAdam as Adam
 from utils import Timers
@@ -78,12 +72,18 @@ def get_model(args):
                     _module.float()

     # Wrap model for distributed training.
-    if USE_TORCH_DDP:
+    if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        model = DDP(model, device_ids=[i], output_device=i,
-                    process_group=mpu.get_data_parallel_group())
+        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
+        model = args.DDP_type(model, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+    elif args.DDP_impl == 'local':
+        args.DDP_type = LocalDDP
+        model = args.DDP_type(model)
     else:
-        model = DDP(model)
+        print_rank_0('Unknown DDP implementation specified: {}. '
+                     'Exiting.'.format(args.DDP_impl))
+        exit()

     return model
@@ -92,7 +92,7 @@ def get_optimizer(model, args):
     """Set up the optimizer."""

     # Build parameter groups (weight decay and non-decay).
-    while isinstance(model, (DDP, FP16_Module)):
+    while isinstance(model, (args.DDP_type, FP16_Module)):
         model = model.module
     layers = model.model.bert.encoder.layer
     pooler = model.model.bert.pooler
@@ -232,7 +232,7 @@ def forward_step(data_iterator, model, args, timers):
     return lm_loss, nsp_loss


-def backward_step(optimizer, model, lm_loss, nsp_loss, args):
+def backward_step(optimizer, model, lm_loss, nsp_loss, args, timers):
     """Backward step."""

     # Total loss.
@@ -252,9 +252,11 @@ def backward_step(optimizer, model, lm_loss, nsp_loss, args):
     reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
     torch.distributed.all_reduce(reduced_losses.data)
     reduced_losses.data = reduced_losses.data / args.world_size
-    if not USE_TORCH_DDP:
+    if args.DDP_impl == 'local':
+        timers('allreduce').start()
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
+        timers('allreduce').stop()
     lm_loss_reduced = reduced_losses[0]
     nsp_loss_reduced = reduced_losses[1]
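The 'allreduce' timer only exists on the local path because, as the removed source comment noted, torch DDP overlaps the gradient all-reduce with backward(), whereas the local DistributedDataParallel exposes it as a separate allreduce_params call. A runnable sketch of the torch-side behaviour, using a single-process gloo group and a toy CPU model as stand-ins (not the repository's setup):

# Sketch only: torch DDP folds the gradient all-reduce into backward(),
# so there is no separate step to time on that path.
import os
import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
ddp_model = torch.nn.parallel.DistributedDataParallel(model)
ddp_model(torch.randn(2, 4)).sum().backward()  # gradients are all-reduced here

# The local DistributedDataParallel only accumulates local gradients during
# backward(); the explicit model.allreduce_params(...) call above is what the
# new timers('allreduce') start/stop pair measures.
dist.destroy_process_group()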
@@ -285,7 +287,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss,
-                                                      nsp_loss, args)
+                                                      nsp_loss, args, timers)
     timers('backward').stop()

     # Update parameters.
@@ -338,8 +340,12 @@ def train(model, optimizer, lr_scheduler,
         # Logging.
-        timers_to_log = ['forward', 'backward', 'optimizer',
-                         'batch generator', 'data loader']
+        if args.DDP_impl == 'torch':
+            timers_to_log = ['forward', 'backward', 'optimizer',
+                             'batch generator', 'data loader']
+        else:
+            timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
+                             'batch generator', 'data loader']

         learning_rate = optimizer.param_groups[0]['lr']
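The Timers class comes from utils and is not shown in this diff; a minimal stand-in with the name-keyed start/stop/elapsed interface assumed by the calls above could look like the following (the real class may differ, for example by also handling logging and CUDA synchronization):

# Sketch only: stand-in for the name-keyed timer interface used above.
import time

class SimpleTimers:
    """Each name maps to a timer with start/stop and accumulated elapsed time."""

    class _Timer:
        def __init__(self):
            self.elapsed = 0.0
            self._start = None

        def start(self):
            self._start = time.time()

        def stop(self):
            self.elapsed += time.time() - self._start
            self._start = None

    def __init__(self):
        self._timers = {}

    def __call__(self, name):
        return self._timers.setdefault(name, self._Timer())

timers = SimpleTimers()
timers('allreduce').start()
timers('allreduce').stop()
print(timers('allreduce').elapsed)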
@@ -425,7 +431,7 @@ def evaluate(data_iterator, model, args, timers, verbose = False):
             lm_loss, nsp_loss = forward_step(data_iterator, model,
                                              args, timers)
             # Reduce across processes.
-            if isinstance(model, DDP):
+            if isinstance(model, args.DDP_type):
                 reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
                 torch.distributed.all_reduce(reduced_losses.data)
                 reduced_losses.data = reduced_losses.data/args.world_size