Commit 5489bda9 authored by Deepak Narayanan

More comments and some cleanup (e.g., better variable names)

parent 626645c0
@@ -554,12 +554,24 @@ class ParallelTransformer(MegatronModule):
self_attn_mask_type=self_attn_mask_type)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
- 'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
+ 'num_layers_per_stage must be divisible by ' \
+ 'virtual_pipeline_model_parallel_size'
+ # Number of layers in each model chunk is the number of layers in the stage,
+ # divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+ # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+ # layers to stages like (each list is a model chunk):
+ # Stage 0: [0] [2] [4] [6]
+ # Stage 1: [1] [3] [5] [7]
+ # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+ # layers to stages like (each list is a model chunk):
+ # Stage 0: [0, 1] [4, 5]
+ # Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
else:
+ # Each stage gets a contiguous set of layers.
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
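The offset arithmetic above is easy to check in isolation. Below is a minimal standalone sketch (plain Python, with local stand-ins rather than the Megatron `args` and `mpu` state) that reproduces both commented examples:

```python
# Standalone sketch of the layer-offset logic above; all names are local
# stand-ins, not Megatron variables.
def layer_offset(pipeline_rank, virtual_rank, num_layers,
                 pipeline_size, virtual_size):
    layers_per_chunk = (num_layers // pipeline_size) // virtual_size
    return (virtual_rank * (num_layers // virtual_size)
            + pipeline_rank * layers_per_chunk)

# 8 layers, 2 stages, 4 model chunks: stage 0 holds layers [0] [2] [4] [6].
assert [layer_offset(0, v, 8, 2, 4) for v in range(4)] == [0, 2, 4, 6]

# 8 layers, 2 stages, 2 virtual stages: stage 1 holds [2, 3] [6, 7]
# (each chunk spans 2 consecutive layers starting at the offset).
assert [layer_offset(1, v, 8, 2, 2) for v in range(2)] == [2, 6]
```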
...
@@ -271,10 +271,8 @@ def get_pipeline_model_parallel_rank():
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != 0:
+ if get_virtual_pipeline_model_parallel_world_size() is not None and \
+ get_virtual_pipeline_model_parallel_rank() != 0:
return False
return get_pipeline_model_parallel_rank() == 0
@@ -282,11 +280,11 @@ def is_pipeline_first_stage(ignore_virtual=False):
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != (
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - 1):
+ virtual_pipeline_model_parallel_world_size = \
+ get_virtual_pipeline_model_parallel_world_size()
+ if virtual_pipeline_model_parallel_world_size is not None and \
+ get_virtual_pipeline_model_parallel_rank() != (
+ virtual_pipeline_model_parallel_world_size - 1):
return False
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
@@ -304,6 +302,12 @@ def set_virtual_pipeline_model_parallel_rank(rank):
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
+ def get_virtual_pipeline_model_parallel_world_size():
+ """Return the virtual pipeline-parallel world size."""
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
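To see why both predicates must consult the virtual rank, here is a toy illustration of the first/last-stage checks (assumed standalone values, not the module's globals):

```python
# Toy model of is_pipeline_first_stage / is_pipeline_last_stage with
# 2 pipeline stages and 2 virtual stages; plain functions, not mpu state.
PIPELINE_SIZE, VIRTUAL_SIZE = 2, 2

def is_first(pipeline_rank, virtual_rank):
    # True first stage: first pipeline rank AND first model chunk.
    return pipeline_rank == 0 and virtual_rank == 0

def is_last(pipeline_rank, virtual_rank):
    # True last stage: last pipeline rank AND last model chunk.
    return (pipeline_rank == PIPELINE_SIZE - 1
            and virtual_rank == VIRTUAL_SIZE - 1)

assert is_first(0, 0) and not is_first(0, 1)
assert is_last(1, 1) and not is_last(1, 0)
```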
...
@@ -23,7 +23,24 @@ from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False):
"""Communicate tensors between stages."""
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
@@ -50,6 +67,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
device=torch.cuda.current_device(),
dtype=dtype)
+ # Split tensor into smaller chunks if using scatter-gather optimization.
if args.scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
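The split helper trims each sent tensor to a 1/world_size slice, and the receiver reassembles the full tensor (see the gather below). A hedged sketch of that round trip, with the helper semantics assumed from the function names rather than taken from the mpu implementation:

```python
import torch
import torch.distributed as dist

def split_1d(tensor, rank, world_size):
    # Keep only this rank's equal slice of the flattened tensor
    # (assumes numel is divisible by world_size).
    flat = tensor.view(-1)
    per_rank = flat.numel() // world_size
    return flat[rank * per_rank:(rank + 1) * per_rank].clone()

def gather_1d(chunk, world_size, group=None):
    # All-gather the slices and reassemble the flattened tensor.
    parts = [torch.empty_like(chunk) for _ in range(world_size)]
    dist.all_gather(parts, chunk, group=group)
    return torch.cat(parts)
```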
@@ -67,27 +85,32 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
else:
ops = []
if tensor_send_prev is not None:
- send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
- mpu.get_pipeline_model_parallel_prev_rank())
+ send_prev_op = torch.distributed.P2POp(
+ torch.distributed.isend, tensor_send_prev,
+ mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
- recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
- mpu.get_pipeline_model_parallel_prev_rank())
+ recv_prev_op = torch.distributed.P2POp(
+ torch.distributed.irecv, tensor_recv_prev,
+ mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
- send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
- mpu.get_pipeline_model_parallel_next_rank())
+ send_next_op = torch.distributed.P2POp(
+ torch.distributed.isend, tensor_send_next,
+ mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
- recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
- mpu.get_pipeline_model_parallel_next_rank())
+ recv_next_op = torch.distributed.P2POp(
+ torch.distributed.irecv, tensor_recv_next,
+ mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
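The P2POp/batch_isend_irecv pattern above generalizes to any pairing of neighbors. A minimal two-rank sketch (assumes torch.distributed is already initialized, e.g. under torchrun; this is an illustration, not the Megatron helper itself):

```python
import torch
import torch.distributed as dist

def exchange_with_peer(send_tensor, peer):
    # Post a batched isend/irecv pair with one peer, then wait on both.
    recv_tensor = torch.empty_like(send_tensor)
    ops = [dist.P2POp(dist.isend, send_tensor, peer),
           dist.P2POp(dist.irecv, recv_tensor, peer)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    # Mirror the synchronize above to guard against the same race on GPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return recv_tensor
```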
+ # If using scatter-gather optimization, gather smaller chunks.
if args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
@@ -101,6 +124,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
def recv_forward(timers=None, use_ring_exchange=False):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
@@ -118,6 +142,7 @@ def recv_forward(timers=None, use_ring_exchange=False):
def recv_backward(timers=None, use_ring_exchange=False):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
@@ -135,6 +160,7 @@ def recv_backward(timers=None, use_ring_exchange=False):
def send_forward(output_tensor, timers=None, use_ring_exchange=False):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
@@ -149,6 +175,7 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):
def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
timers('backward-send').start()
@@ -163,6 +190,7 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
@@ -180,6 +208,7 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
@@ -197,6 +226,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
input_tensor, _ = _communicate(
@@ -211,6 +241,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
@@ -227,6 +258,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
input_tensor, output_tensor_grad = _communicate(
...
@@ -136,19 +136,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
- def get_model_chunk_id(k, forward):
+ def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
- k_in_group = k % (pipeline_parallel_size * num_model_chunks)
- i = k_in_group // pipeline_parallel_size
+ microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+ model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
- i = (num_model_chunks - i - 1)
- return i
+ model_chunk_id = (num_model_chunks - model_chunk_id - 1)
+ return model_chunk_id
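The renamed helper is straightforward to sanity-check outside the schedule. A standalone check with assumed toy sizes (pipeline_parallel_size = 2, num_model_chunks = 2):

```python
pipeline_parallel_size, num_model_chunks = 2, 2  # assumed toy values

def get_model_chunk_id(microbatch_id, forward):
    microbatch_id_in_group = microbatch_id % (
        pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

# Forward microbatches visit chunks 0,0,1,1,...; backward reverses the order.
assert [get_model_chunk_id(i, True) for i in range(8)] == [0, 0, 1, 1, 0, 0, 1, 1]
assert [get_model_chunk_id(i, False) for i in range(8)] == [1, 1, 0, 0, 1, 1, 0, 0]
```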
- def forward_step_helper(k):
+ def forward_step_helper(microbatch_id):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
- model_chunk_id = get_model_chunk_id(k, forward=True)
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_first_stage():
@@ -164,11 +164,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
return output_tensor
- def backward_step_helper(k):
+ def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
- model_chunk_id = get_model_chunk_id(k, forward=False)
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage():
@@ -317,8 +317,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
return losses_reduced
- def forward_backward_pipelining(forward_step_func, data_iterator, model,
- optimizer, timers, forward_only):
+ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
+ model, optimizer, timers,
+ forward_only):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
...
@@ -50,7 +50,7 @@ from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
- from megatron.schedules import forward_backward_pipelining
+ from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory
@@ -340,7 +340,7 @@ def train_step(forward_step_func, data_iterator,
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
- forward_backward_func = forward_backward_pipelining
+ forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
losses_reduced = forward_backward_func(
@@ -681,7 +681,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
- forward_backward_func = forward_backward_pipelining
+ forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
loss_dicts = forward_backward_func(
@@ -692,8 +692,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
# Reduce across processes.
for loss_dict in loss_dicts:
for key in loss_dict:
- total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
- loss_dict[key]
+ total_loss_dict[key] = total_loss_dict.get(
+ key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \
...
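After the rename, the schedule selection in train_step/evaluate reduces to a three-way choice. A sketch of that dispatch (argument plumbing elided; the imported names are the ones from the hunk above):

```python
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.schedules import forward_backward_pipelining_without_interleaving

def choose_forward_backward_func(args, pipeline_model_parallel):
    # Mirrors the branches in train_step/evaluate above.
    if not pipeline_model_parallel:
        return forward_backward_no_pipelining
    if args.virtual_pipeline_model_parallel_size is not None:
        return forward_backward_pipelining_with_interleaving
    return forward_backward_pipelining_without_interleaving
```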