Commit dd889062 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Interleaved pipeline execution and code refactoring

- Split a model's computation into multiple virtual stages as needed,
and schedule communication correctly between these virtual stages
- Move schedule code into `schedules.py` and communication code into
`p2p_communication.py`
- Use hyphens instead of spaces in all time logging for consistency
- Factor out code in megatron/training.py into helper functions
- Refactor evaluate() function: make it use forward_backward_schedule
functions
parent c1faa9fe
......@@ -557,6 +557,8 @@ def _add_distributed_args(parser):
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
group.add_argument('--virtual-pipeline-model-parallel-size', type=int, default=None,
help='Number of virtual pipeline stages in physical stage.')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
......
......@@ -111,8 +111,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
args = get_args()
# Only rank zero of the data parallel writes to the disk.
if isinstance(model, torchDDP):
model = model.module
unwrapped_model = []
for model_module in model:
if isinstance(model_module, torchDDP):
model_module = model_module.module
unwrapped_model.append(model_module)
model = unwrapped_model
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
......@@ -124,7 +128,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint()
if len(model) == 1:
state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
# Optimizer stuff.
if not args.no_save_optim:
......@@ -211,8 +220,13 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
args = get_args()
load_dir = getattr(args, load_arg)
if isinstance(model, torchDDP):
model = model.module
unwrapped_model = []
for model_module in model:
if isinstance(model_module, torchDDP):
model_module = model_module.module
unwrapped_model.append(model_module)
model = unwrapped_model
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir)
......@@ -297,7 +311,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
print_rank_0('could not find arguments in the checkpoint ...')
# Model.
model.load_state_dict(state_dict['model'], strict=strict)
if len(model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering
if get_checkpoint_version() < 2.0:
......
......@@ -133,7 +133,8 @@ def _initialize_distributed():
print('model parallel is already initialized')
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size)
def _init_autoresume():
......
......@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):
def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage():
if mpu.is_pipeline_first_stage(ignore_virtual=True):
return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage():
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false')
......
......@@ -552,7 +552,15 @@ class ParallelTransformer(MegatronModule):
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
else:
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
......
......@@ -38,6 +38,7 @@ from .initialize import get_pipeline_model_parallel_next_rank
from .initialize import get_pipeline_model_parallel_prev_rank
from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
......
......@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
......@@ -48,7 +51,8 @@ def is_unitialized():
def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1):
pipeline_model_parallel_size_=1,
virtual_pipeline_model_parallel_size_=None):
"""
Initialize model data parallel groups.
......@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size_ is not None:
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
......@@ -258,17 +268,42 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_pipeline_first_stage():
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != 0:
return False
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage():
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != (
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - 1):
return False
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
def set_virtual_pipeline_model_parallel_rank(rank):
"""Set the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
......@@ -276,11 +311,13 @@ def get_tensor_model_parallel_src_rank():
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
......@@ -294,6 +331,7 @@ def get_pipeline_model_parallel_next_rank():
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
......@@ -301,6 +339,7 @@ def get_pipeline_model_parallel_prev_rank():
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
......
......@@ -23,7 +23,7 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(module):
def _get_params_for_weight_decay_optimization(modules):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and baises will have no weight decay but the rest will.
"""
......@@ -32,18 +32,19 @@ def _get_params_for_weight_decay_optimization(module):
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
for module in modules:
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args
from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False):
"""Communicate tensors between stages."""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_prev:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Send tensors in both the forward and backward directions as appropriate.
if use_ring_exchange:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group())
else:
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
return tensor_recv_prev, tensor_recv_next
def recv_forward(timers=None, use_ring_exchange=False):
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
def recv_backward(timers=None, use_ring_exchange=False):
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor, timers=None, use_ring_exchange=False):
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-send').stop()
def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
if not mpu.is_pipeline_first_stage():
if timers is not None:
timers('backward-send').start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('backward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
if timers is not None:
timers('forward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
use_ring_exchange=True)
if timers is not None:
timers('forward-send-forward-recv').stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
if timers is not None:
timers('backward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
use_ring_exchange=True)
if timers is not None:
timers('backward-send-backward-recv').stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, timers=None):
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
use_ring_exchange=True)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import get_num_microbatches
from megatron.p2p_communication import recv_forward, recv_backward
from megatron.p2p_communication import send_forward, send_backward
from megatron.p2p_communication import send_forward_recv_backward, send_backward_recv_forward
from megatron.p2p_communication import send_forward_recv_forward, send_backward_recv_backward
from megatron.p2p_communication import send_forward_backward_recv_forward_backward
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
"""Forward step."""
timers = get_timers()
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
return output_tensor
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
"""Backward step."""
args = get_args()
timers = get_timers()
timers('backward-compute').start()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
timers('backward-compute').stop()
return input_tensor_grad
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run forward and backward passes without inter-stage communication."""
assert len(model) == 1
model = model[0]
losses_reduced = []
for i in range(get_num_microbatches()):
input_tensor, output_tensor_grad = None, None
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
return losses_reduced
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run interleaved 1F1B schedule."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
else:
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
(pipeline_parallel_size -
mpu.get_pipeline_model_parallel_rank() - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
def get_model_chunk_id(k, forward):
k_in_group = k % (pipeline_parallel_size * num_model_chunks)
i = k_in_group // pipeline_parallel_size
if not forward:
i = (num_model_chunks - i - 1)
return i
def forward_step_helper(k):
model_chunk_id = get_model_chunk_id(k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(forward_step_func, data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
return output_tensor
def backward_step_helper(k):
model_chunk_id = get_model_chunk_id(k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
return input_tensor_grad
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(recv_forward(timers, use_ring_exchange=True))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
if mpu.is_pipeline_last_stage():
output_tensor = None
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
input_tensor, output_tensor_grad = \
send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = send_forward_recv_forward(output_tensor, recv_prev, timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if mpu.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if mpu.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
# If last iteration, don't receive; we already received one extra before the
# start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
input_tensor, output_tensor_grad = \
send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
# Run cooldown backward passes.
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
recv_backward(timers, use_ring_exchange=True))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
send_backward_recv_backward(input_tensor_grad, recv_next, timers))
return losses_reduced
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
assert len(model) == 1
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
send_forward(output_tensor, timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = recv_forward(timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
send_forward(output_tensor, timers)
else:
output_tensor_grad = send_forward_recv_backward(output_tensor, timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if forward_only:
if not last_iteration:
input_tensor = recv_forward(timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad, timers)
else:
input_tensor = send_backward_recv_forward(input_tensor_grad, timers)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
send_backward(input_tensor_grad, timers)
return losses_reduced
This diff is collapsed.
......@@ -38,7 +38,7 @@ def model_provider():
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
if mpu.get_pipeline_model_parallel_world_size() > 1:
def model_provider_pipelined():
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
......@@ -51,6 +51,17 @@ def model_provider():
else:
model = BertModelIntermediateStage(
num_tokentypes=num_tokentypes)
return model
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model.append(model_provider_pipelined())
else:
model = model_provider_pipelined()
else:
model = BertModel(
num_tokentypes=num_tokentypes,
......@@ -92,8 +103,8 @@ def forward_step(data_iterator, model, input_tensor):
# Get the batch.
timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator)
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
data_iterator)
timers('batch-generator').stop()
if not args.bert_binary_head:
......
......@@ -35,8 +35,8 @@ def model_provider():
"""Build the model."""
print_rank_0('building GPT model ...')
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
def model_provider_pipelined():
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPTModelFirstStage(num_tokentypes=0)
......@@ -46,6 +46,17 @@ def model_provider():
else:
model = GPTModelIntermediateStage(
num_tokentypes=0)
return model
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model.append(model_provider_pipelined())
else:
model = model_provider_pipelined()
else:
model = GPTModel(num_tokentypes=0, parallel_output=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment