"docs/vscode:/vscode.git/clone" did not exist on "8ec6b873b24257a8d6c67df9213566fbbc738ebc"
Commit 4947002d authored by Mohammad Shoeybi

Merge branch 'torchddp' into 'master'

Use DDP command line argument instead of source flag in pretrain_bert.py.

See merge request ADLR/megatron-lm!2
parents a54978bb 72c5f666
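
The new behavior is driven by `args.DDP_impl`, but the argument registration itself is not part of this diff (it lives in the argument-parsing module, which this merge does not show). A minimal sketch of how such a flag could be registered, assuming argparse and assuming the option spelling `--DDP-impl`:

    # Hypothetical sketch only: the real Megatron-LM argument parser is not in
    # this diff. Option name, default, and help text here are assumptions.
    import argparse

    parser = argparse.ArgumentParser(description='Megatron-LM BERT pretraining')
    parser.add_argument('--DDP-impl', default='local',
                        choices=['local', 'torch'],
                        help="Which DistributedDataParallel wrapper to use: the "
                             "repository's model.DistributedDataParallel ('local') "
                             "or torch.nn.parallel ('torch').")

    args = parser.parse_args(['--DDP-impl', 'torch'])
    # argparse maps '--DDP-impl' to the attribute name used throughout the diff.
    assert args.DDP_impl == 'torch'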
@@ -15,9 +15,6 @@
 """Pretrain BERT"""
 
-# Flag to use Pytorch ddp which uses overlapping communication and computation.
-USE_TORCH_DDP = False
-
 from datetime import datetime
 import os
 import random
@@ -33,10 +30,7 @@ from learning_rates import AnnealingLR
 from model import BertModel
 from model import get_params_for_weight_decay_optimization
 from model import gpt2_get_params_for_weight_decay_optimization
-if USE_TORCH_DDP:
-    from torch.nn.parallel.distributed import DistributedDataParallel as DDP
-else:
-    from model import DistributedDataParallel as DDP
+from model import DistributedDataParallel as LocalDDP
 import mpu
 from apex.optimizers import FusedAdam as Adam
 from utils import Timers
@@ -78,12 +72,18 @@ def get_model(args):
         _module.float()
 
     # Wrap model for distributed training.
-    if USE_TORCH_DDP:
+    if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        model = DDP(model, device_ids=[i], output_device=i,
-                    process_group=mpu.get_data_parallel_group())
+        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
+        model = args.DDP_type(model, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+    elif args.DDP_impl == 'local':
+        args.DDP_type = LocalDDP
+        model = args.DDP_type(model)
     else:
-        model = DDP(model)
+        print_rank_0('Unknown DDP implementation specified: {}. '
+                     'Exiting.'.format(args.DDP_impl))
+        exit()
 
     return model
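
The rest of the diff relies on only a small surface from the wrapper imported as `LocalDDP`: construction from a bare module, a `.module` attribute for unwrapping, and an `allreduce_params(reduce_after=..., fp32_allreduce=...)` method for gradient reduction. The real `model.DistributedDataParallel` is not shown in this merge; a minimal sketch of that interface, under those assumptions:

    import torch
    import torch.distributed as dist

    class LocalDDPSketch(torch.nn.Module):
        """Illustrative stand-in for model.DistributedDataParallel (hypothetical).

        Sketches only the surface this diff uses: .module for unwrapping and
        allreduce_params() for averaging gradients across data-parallel ranks.
        """

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *inputs, **kwargs):
            return self.module(*inputs, **kwargs)

        def allreduce_params(self, reduce_after=True, fp32_allreduce=False):
            world_size = dist.get_world_size()
            for param in self.module.parameters():
                if param.grad is None:
                    continue
                grad = param.grad.data
                # Optionally upcast fp16 gradients before the reduction; this
                # mirrors the fp32_allreduce flag passed in backward_step below.
                upcast = fp32_allreduce and grad.dtype == torch.half
                if upcast:
                    grad = grad.float()
                if not reduce_after:
                    grad.div_(world_size)  # pre-divide when reduce_after=False
                dist.all_reduce(grad)
                if reduce_after:
                    grad.div_(world_size)
                if upcast:
                    param.grad.data.copy_(grad)  # cast back to fp16 storage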
@@ -92,7 +92,7 @@ def get_optimizer(model, args):
     """Set up the optimizer."""
 
     # Build parameter groups (weight decay and non-decay).
-    while isinstance(model, (DDP, FP16_Module)):
+    while isinstance(model, (args.DDP_type, FP16_Module)):
         model = model.module
     layers = model.model.bert.encoder.layer
     pooler = model.model.bert.pooler
@@ -232,7 +232,7 @@ def forward_step(data_iterator, model, args, timers):
     return lm_loss, nsp_loss
 
 
-def backward_step(optimizer, model, lm_loss, nsp_loss, args):
+def backward_step(optimizer, model, lm_loss, nsp_loss, args, timers):
     """Backward step."""
 
     # Total loss.
@@ -252,9 +252,11 @@ def backward_step(optimizer, model, lm_loss, nsp_loss, args):
     reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
     torch.distributed.all_reduce(reduced_losses.data)
     reduced_losses.data = reduced_losses.data / args.world_size
-    if not USE_TORCH_DDP:
+    if args.DDP_impl == 'local':
+        timers('allreduce').start()
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
+        timers('allreduce').stop()
     lm_loss_reduced = reduced_losses[0]
     nsp_loss_reduced = reduced_losses[1]
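
The new `timers('allreduce')` calls bracket the local gradient all-reduce so that the 'allreduce' entry added to `timers_to_log` below has something to report. `utils.Timers` itself is not part of this merge; a minimal sketch of the named start/stop interface the diff exercises (an illustration under that assumption, not the repository's implementation):

    import time
    import torch

    class TimersSketch:
        """Hypothetical stand-in for utils.Timers, covering only start/stop."""

        class _Timer:
            def __init__(self):
                self.elapsed_ = 0.0
                self.start_time_ = None

            def _sync(self):
                # Synchronize so in-flight GPU work is attributed correctly.
                if torch.cuda.is_available():
                    torch.cuda.synchronize()

            def start(self):
                self._sync()
                self.start_time_ = time.time()

            def stop(self):
                self._sync()
                self.elapsed_ += time.time() - self.start_time_
                self.start_time_ = None

        def __init__(self):
            self._timers = {}

        def __call__(self, name):
            # timers('allreduce') returns the named timer, creating it on demand.
            if name not in self._timers:
                self._timers[name] = self._Timer()
            return self._timers[name]

Usage matches the diff: timers('allreduce').start() before the reduction, timers('allreduce').stop() after it.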
@@ -285,7 +287,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss,
-                                                      nsp_loss, args)
+                                                      nsp_loss, args, timers)
     timers('backward').stop()
 
     # Update parameters.
@@ -338,8 +340,12 @@ def train(model, optimizer, lr_scheduler,
             # Logging.
-            timers_to_log = ['forward', 'backward', 'optimizer',
-                             'batch generator', 'data loader']
+            if args.DDP_impl == 'torch':
+                timers_to_log = ['forward', 'backward', 'optimizer',
+                                 'batch generator', 'data loader']
+            else:
+                timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
+                                 'batch generator', 'data loader']
 
             learning_rate = optimizer.param_groups[0]['lr']
@@ -425,7 +431,7 @@ def evaluate(data_iterator, model, args, timers, verbose = False):
             lm_loss, nsp_loss = forward_step(data_iterator, model,
                                              args, timers)
 
             # Reduce across processes.
-            if isinstance(model, DDP):
+            if isinstance(model, args.DDP_type):
                 reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
                 torch.distributed.all_reduce(reduced_losses.data)
                 reduced_losses.data = reduced_losses.data/args.world_size
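
Both backward_step and evaluate share the same reduction pattern: concatenate the two scalar losses into one tensor, all-reduce them in a single call, and divide by the data-parallel world size to get the mean. A self-contained sketch of that pattern (the single-rank 'gloo' group exists only so the snippet runs standalone):

    import os
    import torch
    import torch.distributed as dist

    # Single-rank process group purely so this snippet is runnable on its own.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    lm_loss = torch.tensor(2.0)
    nsp_loss = torch.tensor(0.5)

    # One fused all-reduce for both losses instead of two separate calls.
    reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
    dist.all_reduce(reduced_losses)
    reduced_losses = reduced_losses / dist.get_world_size()

    print(reduced_losses[0].item(), reduced_losses[1].item())  # 2.0 0.5 on one rank

    dist.destroy_process_group()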