"docs/source/vscode:/vscode.git/clone" did not exist on "81e6345ddf893c594f6d76406a388fa012cb0a29"
Commit 72c5f666 authored by Jared Casper

Use a DDP command-line argument instead of a source flag in pretrain_bert.py.

Note that there is currently an issue with BERT using Torch DDP.
parent a54978bb
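
The new switch appears in the diff below only as args.DDP_impl with the values 'torch' and 'local'; the argument definition itself is not part of this diff. As a rough sketch only, assuming argparse and an option spelled --DDP-impl (the option name, default, and help text here are guesses, not the repository's actual argument-parsing code), it could look like:

    import argparse

    parser = argparse.ArgumentParser(description='Pretrain BERT')
    # Hypothetical flag: only the attribute name DDP_impl and the values
    # 'local' and 'torch' are taken from the diff; everything else is assumed.
    parser.add_argument('--DDP-impl', default='local',
                        choices=['local', 'torch'],
                        help='Which DistributedDataParallel implementation to use: '
                             'the local model.DistributedDataParallel or '
                             'torch.nn.parallel.DistributedDataParallel.')

    args = parser.parse_args(['--DDP-impl', 'torch'])
    print(args.DDP_impl)  # argparse maps '--DDP-impl' to the attribute DDP_impl

argparse converts the dash in --DDP-impl to an underscore, which is why the script below can read args.DDP_impl.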
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -15,9 +15,6 @@
 """Pretrain BERT"""
 
-# Flag to use Pytorch ddp which uses overlapping communication and computation.
-USE_TORCH_DDP = False
-
 from datetime import datetime
 import os
 import random
@@ -33,10 +30,7 @@ from learning_rates import AnnealingLR
 from model import BertModel
 from model import get_params_for_weight_decay_optimization
 from model import gpt2_get_params_for_weight_decay_optimization
-if USE_TORCH_DDP:
-    from torch.nn.parallel.distributed import DistributedDataParallel as DDP
-else:
-    from model import DistributedDataParallel as DDP
+from model import DistributedDataParallel as LocalDDP
 import mpu
 from apex.optimizers import FusedAdam as Adam
 from utils import Timers
@@ -78,12 +72,18 @@ def get_model(args):
                     _module.float()
 
     # Wrap model for distributed training.
-    if USE_TORCH_DDP:
+    if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        model = DDP(model, device_ids=[i], output_device=i,
-                    process_group=mpu.get_data_parallel_group())
+        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
+        model = args.DDP_type(model, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+    elif args.DDP_impl == 'local':
+        args.DDP_type = LocalDDP
+        model = args.DDP_type(model)
     else:
-        model = DDP(model)
+        print_rank_0('Unknown DDP implementation specified: {}. '
+                     'Exiting.'.format(args.DDP_impl))
+        exit()
 
     return model
 
@@ -92,7 +92,7 @@ def get_optimizer(model, args):
     """Set up the optimizer."""
 
     # Build parameter groups (weight decay and non-decay).
-    while isinstance(model, (DDP, FP16_Module)):
+    while isinstance(model, (args.DDP_type, FP16_Module)):
         model = model.module
     layers = model.model.bert.encoder.layer
     pooler = model.model.bert.pooler
@@ -232,7 +232,7 @@ def forward_step(data_iterator, model, args, timers):
     return lm_loss, nsp_loss
 
 
-def backward_step(optimizer, model, lm_loss, nsp_loss, args):
+def backward_step(optimizer, model, lm_loss, nsp_loss, args, timers):
     """Backward step."""
 
     # Total loss.
@@ -252,9 +252,11 @@ def backward_step(optimizer, model, lm_loss, nsp_loss, args):
     reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
     torch.distributed.all_reduce(reduced_losses.data)
     reduced_losses.data = reduced_losses.data / args.world_size
-    if not USE_TORCH_DDP:
+    if args.DDP_impl == 'local':
+        timers('allreduce').start()
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
+        timers('allreduce').stop()
     lm_loss_reduced = reduced_losses[0]
     nsp_loss_reduced = reduced_losses[1]
 
@@ -285,7 +287,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss,
-                                                      nsp_loss, args)
+                                                      nsp_loss, args, timers)
     timers('backward').stop()
 
     # Update parameters.
@@ -338,8 +340,12 @@ def train(model, optimizer, lr_scheduler,
     # Logging.
 
-    timers_to_log = ['forward', 'backward', 'optimizer',
-                     'batch generator', 'data loader']
+    if args.DDP_impl == 'torch':
+        timers_to_log = ['forward', 'backward', 'optimizer',
+                         'batch generator', 'data loader']
+    else:
+        timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
+                         'batch generator', 'data loader']
 
     learning_rate = optimizer.param_groups[0]['lr']
@@ -425,7 +431,7 @@ def evaluate(data_iterator, model, args, timers, verbose = False):
             lm_loss, nsp_loss = forward_step(data_iterator, model,
                                              args, timers)
             # Reduce across processes.
-            if isinstance(model, DDP):
+            if isinstance(model, args.DDP_type):
                 reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
                 torch.distributed.all_reduce(reduced_losses.data)
                 reduced_losses.data = reduced_losses.data/args.world_size
......
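
For readers skimming the diff above, here is a condensed, standalone sketch of the new wrapping logic added to get_model. The local_ddp_cls and process_group parameters stand in for Megatron's model.DistributedDataParallel and mpu.get_data_parallel_group(), and the error path is swapped for a plain exception instead of print_rank_0 plus exit(); this illustrates the selection pattern under those assumptions and is not code from the repository:

    import torch

    def wrap_model_for_ddp(model, ddp_impl, local_ddp_cls, process_group=None):
        """Pick a DDP class based on the --DDP-impl choice, remember it,
        and wrap the model (mirrors the branching the commit adds)."""
        if ddp_impl == 'torch':
            device = torch.cuda.current_device()
            ddp_type = torch.nn.parallel.DistributedDataParallel
            model = ddp_type(model, device_ids=[device], output_device=device,
                             process_group=process_group)
        elif ddp_impl == 'local':
            ddp_type = local_ddp_cls
            model = ddp_type(model)
        else:
            raise ValueError('Unknown DDP implementation: {}'.format(ddp_impl))
        # Returning the chosen class lets callers later unwrap or type-check
        # with isinstance(model, ddp_type), as get_optimizer and evaluate do.
        return ddp_type, model

Storing the chosen class (the commit keeps it on args as args.DDP_type) is what lets get_optimizer and evaluate drop the old module-level DDP import while still unwrapping and type-checking the wrapped model.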