Commit 5489bda9 authored by Deepak Narayanan

More comments and some cleanup (e.g., better variable names)

parent 626645c0
@@ -554,12 +554,24 @@ class ParallelTransformer(MegatronModule):
                 self_attn_mask_type=self_attn_mask_type)
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
-                'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
+                'num_layers_per_stage must be divisible by ' \
+                'virtual_pipeline_model_parallel_size'
+            # Number of layers in each model chunk is the number of layers in the stage,
+            # divided by the number of model chunks in a stage.
             self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0]  [2]  [4]  [6]
+            # Stage 1: [1]  [3]  [5]  [7]
+            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0, 1]  [4, 5]
+            # Stage 1: [2, 3]  [6, 7]
             offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                 args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
+            # Each stage gets a contiguous set of layers.
             offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
...
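The offset arithmetic above is easy to sanity-check outside Megatron. Below is a minimal, standalone sketch (not part of the commit; the function name and plain-integer arguments are hypothetical stand-ins for the args/mpu lookups) that reproduces the layer-to-stage assignment described in the new comments.

# layers_per_chunk plays the role of self.num_layers after the division above.
def layer_offset(pipeline_rank, virtual_rank, num_layers, pipeline_size, num_chunks):
    layers_per_chunk = num_layers // (pipeline_size * num_chunks)
    # Each model chunk (virtual stage) owns a block of num_layers // num_chunks layers;
    # within that block, each pipeline stage takes a contiguous slice of layers_per_chunk.
    return virtual_rank * (num_layers // num_chunks) + pipeline_rank * layers_per_chunk

# Reproduces the 8-layer, 2-stage, 4-chunk example from the comment:
# stage 0 -> chunks starting at layers 0, 2, 4, 6; stage 1 -> 1, 3, 5, 7.
for stage in range(2):
    print(stage, [layer_offset(stage, chunk, 8, 2, 4) for chunk in range(4)])
# 0 [0, 2, 4, 6]
# 1 [1, 3, 5, 7]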
@@ -271,10 +271,8 @@ def get_pipeline_model_parallel_rank():

 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-        if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
-                _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != 0:
+        if get_virtual_pipeline_model_parallel_world_size() is not None and \
+                get_virtual_pipeline_model_parallel_rank() != 0:
             return False
     return get_pipeline_model_parallel_rank() == 0

@@ -282,11 +280,11 @@ def is_pipeline_first_stage(ignore_virtual=False):

 def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-        if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
-                _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != (
-                    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - 1):
+        virtual_pipeline_model_parallel_world_size = \
+            get_virtual_pipeline_model_parallel_world_size()
+        if virtual_pipeline_model_parallel_world_size is not None and \
+                get_virtual_pipeline_model_parallel_rank() != (
+                    virtual_pipeline_model_parallel_world_size - 1):
             return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)

@@ -304,6 +302,12 @@ def set_virtual_pipeline_model_parallel_rank(rank):
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


+def get_virtual_pipeline_model_parallel_world_size():
+    """Return the virtual pipeline-parallel world size."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
...
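The point of routing these checks through the new accessor is the virtual-rank test itself. A standalone sketch (a hypothetical mock of the logic above, not Megatron's module state) shows why only the first model chunk on pipeline rank 0 behaves as the first stage once interleaving is enabled.

def is_first_stage(pipeline_rank, virtual_rank, virtual_world_size, ignore_virtual=False):
    # Mirrors is_pipeline_first_stage(): with interleaving, only model chunk 0 on
    # pipeline rank 0 should act as the "first stage" (e.g., read from the data loader).
    if not ignore_virtual:
        if virtual_world_size is not None and virtual_rank != 0:
            return False
    return pipeline_rank == 0

print(is_first_stage(0, 0, 2))                       # True
print(is_first_stage(0, 1, 2))                       # False: chunk 1 on pipeline rank 0
print(is_first_stage(0, 1, 2, ignore_virtual=True))  # True: virtual rank ignored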
@@ -23,7 +23,24 @@ from megatron import mpu

 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  use_ring_exchange=False):
-    """Communicate tensors between stages."""
+    """Communicate tensors between stages. Used as helper method in other
+    communication methods that are used in megatron/schedules.py.
+
+    Takes the following arguments:
+        tensor_send_next: tensor to send to next rank (no tensor sent if
+                          set to None).
+        tensor_send_prev: tensor to send to prev rank (no tensor sent if
+                          set to None).
+        recv_prev: boolean for whether tensor should be received from
+                   previous rank.
+        recv_next: boolean for whether tensor should be received from
+                   next rank.
+        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
+                           API should be used.
+
+    Returns:
+        (tensor_recv_prev, tensor_recv_next)
+    """
     args = get_args()

     # Create placeholder tensors for receive in forward and backward directions
@@ -50,6 +67,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)

+    # Split tensor into smaller chunks if using scatter-gather optimization.
     if args.scatter_gather_tensors_in_pipeline:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
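The scatter-gather optimization referenced by the new comment sends only a 1/t-th slice of the flattened activation per tensor-parallel rank and reassembles the full tensor on the receiving stage (the gather half appears in the next hunk). A single-process sketch of the idea, using hypothetical helper names rather than the actual mpu implementations:

import torch

def split_into_equal_chunks(tensor, tp_world_size, tp_rank):
    # Flatten and keep only this rank's contiguous 1/tp_world_size slice.
    return torch.chunk(tensor.reshape(-1), tp_world_size)[tp_rank].clone()

def gather_chunks(chunks, shape):
    # Reassemble the full tensor from all ranks' slices (an all-gather in practice).
    return torch.cat(chunks).reshape(shape)

activation = torch.arange(8.0).reshape(2, 4)
tp = 2
chunks = [split_into_equal_chunks(activation, tp, r) for r in range(tp)]  # 4 elements each
assert torch.equal(gather_chunks(chunks, activation.shape), activation)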
@@ -67,27 +85,32 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     else:
         ops = []
         if tensor_send_prev is not None:
-            send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
-                                                   mpu.get_pipeline_model_parallel_prev_rank())
+            send_prev_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
             ops.append(send_prev_op)
         if tensor_recv_prev is not None:
-            recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
-                                                   mpu.get_pipeline_model_parallel_prev_rank())
+            recv_prev_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
             ops.append(recv_prev_op)
         if tensor_send_next is not None:
-            send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
-                                                   mpu.get_pipeline_model_parallel_next_rank())
+            send_next_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_next,
+                mpu.get_pipeline_model_parallel_next_rank())
             ops.append(send_next_op)
         if tensor_recv_next is not None:
-            recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
-                                                   mpu.get_pipeline_model_parallel_next_rank())
+            recv_next_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_next,
+                mpu.get_pipeline_model_parallel_next_rank())
             ops.append(recv_next_op)
         reqs = torch.distributed.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
+    # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()

-    tensor_recv_prev_before = tensor_recv_prev
+    # If using scatter-gather optimization, gather smaller chunks.
     if args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
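The batched point-to-point pattern restructured here is plain torch.distributed. Below is a self-contained sketch of the same P2POp / batch_isend_irecv flow, assuming two processes launched with torchrun (e.g. torchrun --nproc_per_node=2 demo.py) and a Gloo backend for a CPU-only run; the file name and variables are hypothetical.

import torch
import torch.distributed as dist

dist.init_process_group(backend='gloo')
rank, world_size = dist.get_rank(), dist.get_world_size()

send_buf = torch.full((4,), float(rank))
recv_buf = torch.empty(4)

# Each rank sends to the next rank and receives from the previous one (ring order),
# queuing both operations and launching them as one batch, then waiting on the requests.
ops = [dist.P2POp(dist.isend, send_buf, (rank + 1) % world_size),
       dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world_size)]
for req in dist.batch_isend_irecv(ops):
    req.wait()
print(f"rank {rank} received {recv_buf.tolist()}")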
@@ -101,6 +124,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,

 def recv_forward(timers=None, use_ring_exchange=False):
+    """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:

@@ -118,6 +142,7 @@ def recv_forward(timers=None, use_ring_exchange=False):

 def recv_backward(timers=None, use_ring_exchange=False):
+    """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
     else:

@@ -135,6 +160,7 @@ def recv_backward(timers=None, use_ring_exchange=False):

 def send_forward(output_tensor, timers=None, use_ring_exchange=False):
+    """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send').start()

@@ -149,6 +175,7 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):

 def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
+    """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
             timers('backward-send').start()

@@ -163,6 +190,7 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):

 def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
+    """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
     else:

@@ -180,6 +208,7 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):

 def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
+    """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:

@@ -197,6 +226,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):

 def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
+    """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
         timers('forward-send-forward-recv').start()
     input_tensor, _ = _communicate(

@@ -211,6 +241,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):

 def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
+    """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
         timers('backward-send-backward-recv').start()
     _, output_tensor_grad = _communicate(

@@ -227,6 +258,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):

 def send_forward_backward_recv_forward_backward(
         output_tensor, input_tensor_grad, recv_prev,
         recv_next, timers=None):
+    """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').start()
     input_tensor, output_tensor_grad = _communicate(
...
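The bodies of these wrapper functions are not shown in this diff; based on their new docstrings, they presumably reduce to _communicate calls along the following lines. This is a sketch of two representatives only, with timer handling omitted, and it assumes the _communicate helper above is in scope.

def recv_forward_sketch():
    # Forward receive: expect a tensor only from the previous pipeline rank.
    input_tensor, _ = _communicate(
        tensor_send_next=None, tensor_send_prev=None,
        recv_prev=True, recv_next=False)
    return input_tensor

def send_forward_recv_backward_sketch(output_tensor):
    # Batched exchange with the next rank: send activations forward and receive
    # the corresponding output gradient in the same call.
    _, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor, tensor_send_prev=None,
        recv_prev=False, recv_next=True)
    return output_tensor_grad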
@@ -136,19 +136,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches

-    def get_model_chunk_id(k, forward):
+    def get_model_chunk_id(microbatch_id, forward):
         """Helper method to get the model chunk ID given the iteration number."""
-        k_in_group = k % (pipeline_parallel_size * num_model_chunks)
-        i = k_in_group // pipeline_parallel_size
+        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
         if not forward:
-            i = (num_model_chunks - i - 1)
-        return i
+            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
+        return model_chunk_id

-    def forward_step_helper(k):
+    def forward_step_helper(microbatch_id):
         """Helper method to run forward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         forward_step())."""
-        model_chunk_id = get_model_chunk_id(k, forward=True)
+        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         if mpu.is_pipeline_first_stage():

@@ -164,11 +164,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
         return output_tensor

-    def backward_step_helper(k):
+    def backward_step_helper(microbatch_id):
         """Helper method to run backward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         backward_step())."""
-        model_chunk_id = get_model_chunk_id(k, forward=False)
+        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         if mpu.is_pipeline_last_stage():

@@ -317,8 +317,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
     return losses_reduced


-def forward_backward_pipelining(forward_step_func, data_iterator, model,
-                                optimizer, timers, forward_only):
+def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
+                                                     model, optimizer, timers,
+                                                     forward_only):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.
...
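The renamed get_model_chunk_id is the heart of the interleaved schedule: it decides which model chunk a given microbatch index runs on. A standalone copy with fixed, hypothetical sizes (pipeline_parallel_size=2, num_model_chunks=2) shows the mapping it produces.

def get_model_chunk_id(microbatch_id, forward, pipeline_parallel_size=2, num_model_chunks=2):
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print([get_model_chunk_id(m, forward=True) for m in range(8)])
# [0, 0, 1, 1, 0, 0, 1, 1]: each chunk runs pipeline_parallel_size microbatches in a row.
print([get_model_chunk_id(m, forward=False) for m in range(8)])
# [1, 1, 0, 0, 1, 1, 0, 0]: backward passes visit the chunks in reverse order.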
@@ -50,7 +50,7 @@ from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
 from megatron.schedules import forward_backward_no_pipelining
-from megatron.schedules import forward_backward_pipelining
+from megatron.schedules import forward_backward_pipelining_without_interleaving
 from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory

@@ -340,7 +340,7 @@ def train_step(forward_step_func, data_iterator,
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
         else:
-            forward_backward_func = forward_backward_pipelining
+            forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
         forward_backward_func = forward_backward_no_pipelining
     losses_reduced = forward_backward_func(

@@ -681,7 +681,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
         else:
-            forward_backward_func = forward_backward_pipelining
+            forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
         forward_backward_func = forward_backward_no_pipelining
     loss_dicts = forward_backward_func(

@@ -692,8 +692,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
         # Reduce across processes.
         for loss_dict in loss_dicts:
             for key in loss_dict:
-                total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
-                    loss_dict[key]
+                total_loss_dict[key] = total_loss_dict.get(
+                    key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

         args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                        * args.micro_batch_size \
...
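The reshaped loss reduction in evaluate() is a key-wise sum over per-microbatch loss dicts. A tiny standalone illustration with hypothetical values, using CPU tensors in place of torch.cuda.FloatTensor:

import torch

loss_dicts = [{'lm loss': torch.tensor([1.0])}, {'lm loss': torch.tensor([2.0])}]
total_loss_dict = {}
for loss_dict in loss_dicts:
    for key in loss_dict:
        total_loss_dict[key] = total_loss_dict.get(
            key, torch.tensor([0.0])) + loss_dict[key]
print(total_loss_dict)  # {'lm loss': tensor([3.])}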