Commit dd889062 authored by Deepak Narayanan

Interleaved pipeline execution and code refactoring

- Split a model's computation into multiple virtual stages as needed,
and schedule communication correctly between these virtual stages
- Move schedule code into `schedules.py` and communication code into
`p2p_communication.py`
- Use hyphens instead of spaces in all time logging for consistency
- Factor out code in megatron/training.py into helper functions
- Refactor evaluate() function: make it use forward_backward_schedule
functions
parent c1faa9fe
@@ -557,6 +557,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
+    group.add_argument('--virtual-pipeline-model-parallel-size', type=int, default=None,
+                       help='Number of virtual pipeline stages in physical stage.')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
...
@@ -111,8 +111,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
+    unwrapped_model = []
+    for model_module in model:
+        if isinstance(model_module, torchDDP):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    model = unwrapped_model

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
@@ -124,7 +128,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
-        state_dict['model'] = model.state_dict_for_save_checkpoint()
+        if len(model) == 1:
+            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+        else:
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

         # Optimizer stuff.
         if not args.no_save_optim:
@@ -211,8 +220,13 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if isinstance(model, torchDDP):
-        model = model.module
+    unwrapped_model = []
+    for model_module in model:
+        if isinstance(model_module, torchDDP):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    model = unwrapped_model

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)
@@ -297,7 +311,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
-    model.load_state_dict(state_dict['model'], strict=strict)
+    if len(model) == 1:
+        model[0].load_state_dict(state_dict['model'], strict=strict)
+    else:
+        for i in range(len(model)):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering
     if get_checkpoint_version() < 2.0:
...
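With interleaving, a rank checkpoints one state dict per model chunk under numbered keys instead of a single 'model' entry. A minimal sketch of that key convention (hypothetical helper, not part of the commit):

# Illustrative helper only (not in the commit): the checkpoint key that the
# save/load logic above uses for chunk i of a rank holding num_chunks chunks.
def model_state_key(i, num_chunks):
    return 'model' if num_chunks == 1 else 'model%d' % i

assert model_state_key(0, 1) == 'model'
assert [model_state_key(i, 2) for i in range(2)] == ['model0', 'model1']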
@@ -133,7 +133,8 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size)
+                                          args.pipeline_model_parallel_size,
+                                          args.virtual_pipeline_model_parallel_size)


 def _init_autoresume():
...
@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):
     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage(ignore_virtual=True):
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')
...
@@ -552,7 +552,15 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type)

-        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+        if args.virtual_pipeline_model_parallel_size is not None:
+            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
+                'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
+            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
+                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
+        else:
+            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
...
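To make the offset arithmetic above concrete, the following stand-alone sketch (illustrative sizes only, not part of the commit) prints which layer numbers each pipeline rank builds for each virtual chunk.

# Illustrative sketch only: layer numbering produced by the offset logic above,
# assuming num_layers=16, pipeline_model_parallel_size=4 and
# virtual_pipeline_model_parallel_size=2 (so 2 layers per chunk).
num_layers = 16
pipeline_size = 4
virtual_size = 2
layers_per_stage = num_layers // pipeline_size          # 4
layers_per_chunk = layers_per_stage // virtual_size     # 2, i.e. self.num_layers above

for rank in range(pipeline_size):
    for virtual_rank in range(virtual_size):
        offset = virtual_rank * (num_layers // virtual_size) + rank * layers_per_chunk
        layers = [i + 1 + offset for i in range(layers_per_chunk)]
        print(f'pipeline rank {rank}, chunk {virtual_rank}: layers {layers}')
# e.g. rank 0 builds layers [1, 2] and [9, 10]; rank 3 builds [7, 8] and [15, 16].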
@@ -38,6 +38,7 @@ from .initialize import get_pipeline_model_parallel_next_rank
 from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
+from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
...
@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

+_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+
 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
@@ -48,7 +51,8 @@ def is_unitialized():
 def initialize_model_parallel(tensor_model_parallel_size_=1,
-                              pipeline_model_parallel_size_=1):
+                              pipeline_model_parallel_size_=1,
+                              virtual_pipeline_model_parallel_size_=None):
     """
     Initialize model data parallel groups.
@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size

+    if virtual_pipeline_model_parallel_size_ is not None:
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+
     rank = torch.distributed.get_rank()

     # Build the data-parallel groups.
@@ -258,17 +268,42 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


-def is_pipeline_first_stage():
+def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+        if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
+                _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != 0:
+            return False
     return get_pipeline_model_parallel_rank() == 0


-def is_pipeline_last_stage():
+def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+        if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
+                _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != (
+                    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - 1):
+            return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)


+def get_virtual_pipeline_model_parallel_rank():
+    """Return the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+
+
+def set_virtual_pipeline_model_parallel_rank(rank):
+    """Set the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
+
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
@@ -276,11 +311,13 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size


 def get_pipeline_model_parallel_first_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]


 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
@@ -294,6 +331,7 @@ def get_pipeline_model_parallel_next_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


 def get_pipeline_model_parallel_prev_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
@@ -301,6 +339,7 @@ def get_pipeline_model_parallel_prev_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
     return torch.distributed.get_world_size(group=get_data_parallel_group())
...
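The ignore_virtual flag exists because, once a rank hosts several chunks, only chunk 0 on the first pipeline rank (and the last chunk on the last rank) should act as the true pipeline endpoints. A stand-alone sketch of the first-stage check (plain arguments in place of the module-level globals; not the mpu code itself):

# Stand-alone sketch (not the mpu module): the first-stage check above, with the
# module-level globals replaced by explicit arguments.
def is_pipeline_first_stage(pipeline_rank, virtual_rank=None, ignore_virtual=False):
    if not ignore_virtual and virtual_rank is not None and virtual_rank != 0:
        return False
    return pipeline_rank == 0

assert is_pipeline_first_stage(0, virtual_rank=0)       # chunk 0 on pipeline rank 0
assert not is_pipeline_first_stage(0, virtual_rank=1)   # a later chunk on rank 0
assert is_pipeline_first_stage(0, virtual_rank=1, ignore_virtual=True)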
@@ -23,7 +23,7 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer

-def _get_params_for_weight_decay_optimization(module):
+def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
@@ -32,18 +32,19 @@ def _get_params_for_weight_decay_optimization(module):
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, LayerNorm):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
+    for module in modules:
+        for module_ in module.modules():
+            if isinstance(module_, LayerNorm):
+                no_weight_decay_params['params'].extend(
+                    [p for p in list(module_._parameters.values())
+                     if p is not None])
+            else:
+                weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n != 'bias'])
+                no_weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n == 'bias'])

     return weight_decay_params, no_weight_decay_params
...
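Since the helper now iterates over a list of modules (one per model chunk), here is a hedged usage sketch of the same grouping logic with torch.nn stand-ins for Megatron's classes (illustrative only, not the commit's code):

import torch

# Toy stand-ins (torch.nn.LayerNorm instead of Megatron's LayerNorm): the real
# call receives the list of model chunks held by this rank.
modules = [torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))]
weight_decay_params, no_weight_decay_params = [], []
for module in modules:
    for module_ in module.modules():
        if isinstance(module_, torch.nn.LayerNorm):
            no_weight_decay_params.extend(
                p for p in module_._parameters.values() if p is not None)
        else:
            weight_decay_params.extend(
                p for n, p in module_._parameters.items()
                if p is not None and n != 'bias')
            no_weight_decay_params.extend(
                p for n, p in module_._parameters.items()
                if p is not None and n == 'bias')

# Linear.weight gets weight decay; Linear.bias and both LayerNorm parameters do not.
assert len(weight_decay_params) == 1 and len(no_weight_decay_params) == 3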
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args
from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False):
"""Communicate tensors between stages."""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_prev:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Send tensors in both the forward and backward directions as appropriate.
if use_ring_exchange:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group())
else:
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
return tensor_recv_prev, tensor_recv_next
def recv_forward(timers=None, use_ring_exchange=False):
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
def recv_backward(timers=None, use_ring_exchange=False):
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor, timers=None, use_ring_exchange=False):
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-send').stop()
def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
if not mpu.is_pipeline_first_stage():
if timers is not None:
timers('backward-send').start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('backward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
if timers is not None:
timers('forward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
use_ring_exchange=True)
if timers is not None:
timers('forward-send-forward-recv').stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
if timers is not None:
timers('backward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
use_ring_exchange=True)
if timers is not None:
timers('backward-send-backward-recv').stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, timers=None):
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
use_ring_exchange=True)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
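For orientation, the helpers above (imported elsewhere in this commit as megatron.p2p_communication) always move activations toward the next pipeline rank and gradients toward the previous one; during the plain 1F1B schedule they pair up roughly as sketched below (an informal summary, not code Megatron executes in this form).

# Informal pairing of the point-to-point helpers across neighbouring stages.
pairings = [
    ('stage i: send_forward(output_tensor)',        'stage i+1: recv_forward()'),
    ('stage i: send_backward(input_tensor_grad)',   'stage i-1: recv_backward()'),
    ('stage i: send_forward_recv_backward(output)', 'stage i+1: send_backward_recv_forward(grad)'),
]
for send_side, recv_side in pairings:
    print(f'{send_side:48s} <-> {recv_side}')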
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import get_num_microbatches
from megatron.p2p_communication import recv_forward, recv_backward
from megatron.p2p_communication import send_forward, send_backward
from megatron.p2p_communication import send_forward_recv_backward, send_backward_recv_forward
from megatron.p2p_communication import send_forward_recv_forward, send_backward_recv_backward
from megatron.p2p_communication import send_forward_backward_recv_forward_backward
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
"""Forward step."""
timers = get_timers()
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
return output_tensor
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
"""Backward step."""
args = get_args()
timers = get_timers()
timers('backward-compute').start()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
timers('backward-compute').stop()
return input_tensor_grad
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run forward and backward passes without inter-stage communication."""
assert len(model) == 1
model = model[0]
losses_reduced = []
for i in range(get_num_microbatches()):
input_tensor, output_tensor_grad = None, None
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
return losses_reduced
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run interleaved 1F1B schedule."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
else:
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
(pipeline_parallel_size -
mpu.get_pipeline_model_parallel_rank() - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
def get_model_chunk_id(k, forward):
k_in_group = k % (pipeline_parallel_size * num_model_chunks)
i = k_in_group // pipeline_parallel_size
if not forward:
i = (num_model_chunks - i - 1)
return i
def forward_step_helper(k):
model_chunk_id = get_model_chunk_id(k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(forward_step_func, data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
return output_tensor
def backward_step_helper(k):
model_chunk_id = get_model_chunk_id(k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
return input_tensor_grad
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(recv_forward(timers, use_ring_exchange=True))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
if mpu.is_pipeline_last_stage():
output_tensor = None
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
input_tensor, output_tensor_grad = \
send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = send_forward_recv_forward(output_tensor, recv_prev, timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if mpu.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if mpu.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
# If last iteration, don't receive; we already received one extra before the
# start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
input_tensor, output_tensor_grad = \
send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
# Run cooldown backward passes.
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
recv_backward(timers, use_ring_exchange=True))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
send_backward_recv_backward(input_tensor_grad, recv_next, timers))
return losses_reduced
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
assert len(model) == 1
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
send_forward(output_tensor, timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = recv_forward(timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
send_forward(output_tensor, timers)
else:
output_tensor_grad = send_forward_recv_backward(output_tensor, timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if forward_only:
if not last_iteration:
input_tensor = recv_forward(timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad, timers)
else:
input_tensor = send_backward_recv_forward(input_tensor_grad, timers)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
send_backward(input_tensor_grad, timers)
return losses_reduced
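To help follow the interleaved bookkeeping above (the new megatron.schedules module), the sketch below (illustrative constants only, not part of the commit) reproduces the microbatch-to-chunk mapping that get_model_chunk_id computes for a small configuration.

# Illustrative only: microbatch-index -> model-chunk mapping used by the
# interleaved schedule, for pipeline_parallel_size=2 and num_model_chunks=2.
pipeline_parallel_size = 2
num_model_chunks = 2

def get_model_chunk_id(k, forward):
    k_in_group = k % (pipeline_parallel_size * num_model_chunks)
    i = k_in_group // pipeline_parallel_size
    if not forward:
        i = num_model_chunks - i - 1
    return i

print([get_model_chunk_id(k, forward=True) for k in range(8)])
# -> [0, 0, 1, 1, 0, 0, 1, 1]: microbatches are issued to chunks in groups of
#    pipeline_parallel_size, cycling through the chunks hosted by each rank.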
@@ -48,6 +48,9 @@ from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
+from megatron.schedules import forward_backward_no_pipelining
+from megatron.schedules import forward_backward_pipelining
+from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory
@@ -107,23 +110,32 @@ def pretrain(train_valid_test_dataset_provider,
     timers = get_timers()

     # Model, optimizer, and learning rate.
-    timers('model and optimizer').start()
+    timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
-    timers('model and optimizer').stop()
+    timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')

     # Data stuff.
-    timers('train/valid/test data iterators').start()
-    train_data_iterator, valid_data_iterator, test_data_iterator \
-        = build_train_valid_test_data_iterators(
-            train_valid_test_dataset_provider)
-    timers('train/valid/test data iterators').stop()
+    timers('train/valid/test-data-iterators-setup').start()
+    if args.virtual_pipeline_model_parallel_size is not None:
+        data_iterators = [
+            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
+            for _ in range(len(model))
+        ]
+        train_data_iterator = [x[0] for x in data_iterators]
+        valid_data_iterator = [x[1] for x in data_iterators]
+        test_data_iterator = [x[2] for x in data_iterators]
+    else:
+        train_data_iterator, valid_data_iterator, test_data_iterator \
+            = build_train_valid_test_data_iterators(
+                train_valid_test_dataset_provider)
+    timers('train/valid/test-data-iterators-setup').stop()
     print_datetime('after dataloaders are built')

     # Print setup timing.
-    print_rank_0('done with setups ...')
-    timers.log(['model and optimizer', 'train/valid/test data iterators'])
+    print_rank_0('done with setup ...')
+    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])

     print_rank_0('training ...')
     iteration = 0
@@ -185,13 +197,16 @@ def get_model(model_provider_func):
     # Build model on cpu.
     model = model_provider_func()
+    if not isinstance(model, list):
+        model = [model]

     # Set tensor model parallel attributes if not set.
     # Only parameters that are already tensor model parallel have these
     # attributes set for them. We should make sure the default attributes
     # are set for all params so the optimizer can use them.
-    for param in model.parameters():
-        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+    for model_module in model:
+        for param in model_module.parameters():
+            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
@@ -199,22 +214,25 @@ def get_model(model_provider_func):
             'model parallel rank ({}, {}): {}'.format(
                 mpu.get_tensor_model_parallel_rank(),
                 mpu.get_pipeline_model_parallel_rank(),
-                sum([p.nelement() for p in model.parameters()])), flush=True)
+                sum([sum([p.nelement() for p in model_module.parameters()])
+                     for model_module in model])), flush=True)

     # GPU allocation.
-    model.cuda(torch.cuda.current_device())
+    for model_module in model:
+        model_module.cuda(torch.cuda.current_device())

     # Fp16 conversion.
     if args.fp16:
-        model = FP16Module(model)
+        model = [FP16Module(model_module) for model_module in model]

     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        model = torchDDP(model, device_ids=[i], output_device=i,
-                         process_group=mpu.get_data_parallel_group())
+        model = [torchDDP(model_module, device_ids=[i], output_device=i,
+                          process_group=mpu.get_data_parallel_group())
+                 for model_module in model]
         return model

     if args.DDP_impl == 'local':
-        model = LocalDDP(model)
+        model = [LocalDDP(model_module) for model_module in model]
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '
@@ -282,11 +300,11 @@ def setup_model_and_optimizer(model_provider_func):
         # Extra barrier is added to make sure all ranks report the
         # max time.
         torch.distributed.barrier()
-        timers('load checkpoint').start()
+        timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
         torch.distributed.barrier()
-        timers('load checkpoint').stop()
-        timers.log(['load checkpoint'])
+        timers('load-checkpoint').stop()
+        timers.log(['load-checkpoint'])
     else:
         args.iteration = 0
@@ -295,292 +313,19 @@ def setup_model_and_optimizer(model_provider_func):
         assert args.DDP_impl == 'local'

     # get model without FP16 and/or TorchDDP wrappers
-    unwrapped_model = model
-    while hasattr(unwrapped_model, 'module'):
-        unwrapped_model = unwrapped_model.module
-
-    if args.iteration == 0 and hasattr(unwrapped_model,
-                                       'init_state_dict_from_bert'):
-        print("Initializing ICT from pretrained BERT model", flush=True)
-        unwrapped_model.init_state_dict_from_bert()
+    for module in model:
+        unwrapped_module = module
+        while hasattr(unwrapped_module, 'module'):
+            unwrapped_module = unwrapped_module.module
+
+        if args.iteration == 0 and hasattr(unwrapped_module,
+                                           'init_state_dict_from_bert'):
+            print("Initializing ICT from pretrained BERT model", flush=True)
+            unwrapped_module.init_state_dict_from_bert()

     return model, optimizer, lr_scheduler
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
"""Communicate tensors between stages."""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_forward:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_backward:
tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Send tensors in both the forward and backward directions as appropriate.
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# Temporary workaround for batch_isend_irecv() race condition.
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
"""Backward step."""
args = get_args()
timers = get_timers()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
def forward_step_with_communication(forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers):
args = get_args()
if not mpu.is_pipeline_first_stage():
timers('forward-recv').start()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
timers('forward-recv').stop()
else:
input_tensor = None
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop()
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
else:
timers('forward-send').start()
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
timers('forward-send').stop()
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
timers('backward-recv').start()
_, output_tensor_grad = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
timers('backward-recv').stop()
# Backward pass for one step.
timers('backward-compute').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
if not mpu.is_pipeline_first_stage():
timers('backward-send').start()
communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=False,
recv_backward=False)
timers('backward-send').stop()
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
input_tensor, last_microbatch,
input_tensors, output_tensors,
losses_reduced, timers):
args = get_args()
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop()
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
output_tensor_grad = None
losses_reduced.append(loss_reduced)
else:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
timers('forward-send-backward-recv').stop()
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
# Backward pass for one step.
timers('backward-compute').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
if not mpu.is_pipeline_first_stage():
timers('backward-send-forward-recv').start()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=(not last_microbatch),
recv_backward=False)
timers('backward-send-forward-recv').stop()
else:
input_tensor = None
return input_tensor
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run forward and backward passes without inter-stage communication."""
args = get_args()
losses_reduced = []
for i in range(get_num_microbatches()):
timers('forward-compute').start()
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
timers('backward-compute').start()
output_tensor_grad = None
backward_step(optimizer, model, input_tensor=None,
output_tensor=output_tensor, output_tensor_grad=None)
timers('backward-compute').stop()
return losses_reduced
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
args = get_args()
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
forward_step_with_communication(
forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
timers('forward-recv').start()
input_tensor, _ = communicate(tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
timers('forward-recv').stop()
# Run 1F1B.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
input_tensor = \
forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
input_tensor, last_iteration,
input_tensors, output_tensors,
losses_reduced, timers)
# Run cooldown backward passes.
for i in range(num_warmup_microbatches):
backward_step_with_communication(
optimizer, model, input_tensors, output_tensors, timers)
return losses_reduced
 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):
     """Single training step."""
@@ -591,17 +336,22 @@ def train_step(forward_step_func, data_iterator,
     optimizer.zero_grad()

     if mpu.get_pipeline_model_parallel_world_size() > 1:
-        losses_reduced = forward_backward_pipelining(
-            forward_step_func, data_iterator, model, optimizer, timers)
+        if args.virtual_pipeline_model_parallel_size is not None:
+            forward_backward_func = forward_backward_pipelining_with_interleaving
+        else:
+            forward_backward_func = forward_backward_pipelining
     else:
-        losses_reduced = forward_backward_no_pipelining(
-            forward_step_func, data_iterator, model, optimizer, timers)
+        forward_backward_func = forward_backward_no_pipelining
+    losses_reduced = forward_backward_func(
+        forward_step_func, data_iterator, model,
+        optimizer, timers, forward_only=False)

     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
-        model.allreduce_params(reduce_after=False,
-                               fp32_allreduce=args.fp32_allreduce)
+        for model_module in model:
+            model_module.allreduce_params(reduce_after=False,
+                                          fp32_allreduce=args.fp32_allreduce)
         timers('backward-params-all-reduce').stop()

     # All-reduce word_embeddings' grad across first and last stages to ensure
@@ -609,9 +359,12 @@ def train_step(forward_step_func, data_iterator,
     # This should only run for models that support pipelined model parallelism
     # (BERT and GPT-2).
     timers('backward-embedding-all-reduce').start()
-    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
+    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
             mpu.get_pipeline_model_parallel_world_size() > 1:
-        unwrapped_model = model
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
+            unwrapped_model = model[0]
+        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
+            unwrapped_model = model[-1]
         while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
             unwrapped_model = unwrapped_model.module
@@ -636,7 +389,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    if mpu.is_pipeline_last_stage():
+    if mpu.is_pipeline_last_stage(ignore_virtual=True):
         # Average loss across microbatches.
         loss_reduced = {}
         for key in losses_reduced[0]:
@@ -692,11 +445,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     add_to_logging('forward-compute')
     add_to_logging('forward-recv')
     add_to_logging('forward-send')
-    add_to_logging('forward-send-backward-recv')
+    add_to_logging('forward-backward-send-forward-backward-recv')
     add_to_logging('backward-compute')
     add_to_logging('backward-recv')
     add_to_logging('backward-send')
     add_to_logging('backward-send-forward-recv')
+    add_to_logging('backward-send-backward-recv')
     add_to_logging('backward-params-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
     add_to_logging('optimizer-copy-to-main-grad')
@@ -745,7 +499,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                          normalizer=total_iterations)

     if iteration % args.log_interval == 0:
-        elapsed_time = timers('interval time').elapsed()
+        elapsed_time = timers('interval-time').elapsed()
         elapsed_time_per_iteration = elapsed_time / total_iterations
         if writer and torch.distributed.get_rank() == 0:
             if args.log_timers_to_tensorboard:
@@ -794,11 +548,11 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
     # Extra barrier is added to make sure
     # all ranks report the max time.
     torch.distributed.barrier()
-    timers('save checkpoint').start()
+    timers('save-checkpoint').start()
     save_checkpoint(iteration, model, optimizer, lr_scheduler)
     torch.distributed.barrier()
-    timers('save checkpoint').stop()
-    timers.log(['save checkpoint'])
+    timers('save-checkpoint').stop()
+    timers.log(['save-checkpoint'])
 def train(forward_step_func, model, optimizer, lr_scheduler,
@@ -811,7 +565,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     write_args_to_tensorboard()

     # Turn on training mode which enables dropout.
-    model.train()
+    for model_module in model:
+        model_module.train()

     # Tracking loss.
     total_loss_dict = {}
@@ -819,7 +574,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration

-    timers('interval time').start()
+    timers('interval-time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
@@ -900,7 +655,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
     args = get_args()

     # Turn on evaluation mode which disables dropout.
-    model.eval()
+    for model_module in model:
+        model_module.eval()

     total_loss_dict = {}
@@ -912,37 +668,30 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                             args.eval_iters))

-            for _ in range(get_num_microbatches()):
-                if not mpu.is_pipeline_first_stage():
-                    input_tensor, _ = communicate(
-                        tensor_send_next=None,
-                        tensor_send_prev=None,
-                        recv_forward=True,
-                        recv_backward=False)
-                else:
-                    input_tensor = None
-
-                # Forward evaluation.
-                output_tensor = forward_step_func(data_iterator, model, input_tensor)
-
-                if mpu.is_pipeline_last_stage():
-                    _, loss_dict = output_tensor
-                    # Reduce across processes.
-                    for key in loss_dict:
-                        total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
-                            loss_dict[key]
-                else:
-                    communicate(
-                        tensor_send_next=output_tensor,
-                        tensor_send_prev=None,
-                        recv_forward=False,
-                        recv_backward=False)
+            if mpu.get_pipeline_model_parallel_world_size() > 1:
+                if args.virtual_pipeline_model_parallel_size is not None:
+                    forward_backward_func = forward_backward_pipelining_with_interleaving
+                else:
+                    forward_backward_func = forward_backward_pipelining
+            else:
+                forward_backward_func = forward_backward_no_pipelining
+            loss_dicts = forward_backward_func(
+                forward_step_func, data_iterator, model, optimizer=None,
+                timers=None, forward_only=True)
+
+            if mpu.is_pipeline_last_stage(ignore_virtual=True):
+                # Reduce across processes.
+                for loss_dict in loss_dicts:
+                    for key in loss_dict:
+                        total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
+                            loss_dict[key]

             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
                                            * get_num_microbatches()
     # Move model back to the train mode.
-    model.train()
+    for model_module in model:
+        model_module.train()

     for key in total_loss_dict:
         total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
...
@@ -38,7 +38,7 @@ def model_provider():
     args = get_args()
     num_tokentypes = 2 if args.bert_binary_head else 0
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    def model_provider_pipelined():
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = BertModelFirstStage(
@@ -51,6 +51,17 @@ def model_provider():
         else:
             model = BertModelIntermediateStage(
                 num_tokentypes=num_tokentypes)
+        return model
+
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            model = []
+            for i in range(args.virtual_pipeline_model_parallel_size):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                model.append(model_provider_pipelined())
+        else:
+            model = model_provider_pipelined()
     else:
         model = BertModel(
             num_tokentypes=num_tokentypes,
@@ -92,8 +103,8 @@ def forward_step(data_iterator, model, input_tensor):
     # Get the batch.
     timers('batch-generator').start()
-    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator)
+    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
+        data_iterator)
     timers('batch-generator').stop()

     if not args.bert_binary_head:
...
@@ -35,8 +35,8 @@ def model_provider():
     """Build the model."""

     print_rank_0('building GPT model ...')
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+
+    def model_provider_pipelined():
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = GPTModelFirstStage(num_tokentypes=0)
@@ -46,6 +46,17 @@ def model_provider():
         else:
             model = GPTModelIntermediateStage(
                 num_tokentypes=0)
+        return model
+
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            model = []
+            for i in range(args.virtual_pipeline_model_parallel_size):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                model.append(model_provider_pipelined())
+        else:
+            model = model_provider_pipelined()
     else:
         model = GPTModel(num_tokentypes=0, parallel_output=True)
...