Commit b93bef00 authored by Lawrence McAfee

comments, cleanup.

parent bea16fa3
......@@ -141,24 +141,9 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers is not divisible by number of layers per virtual ' \
'pipeline stage'
# >>>
# args.virtual_pipeline_model_parallel_size = \
# (args.num_layers // args.pipeline_model_parallel_size) // \
# args.num_layers_per_virtual_pipeline_stage
# <<<
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // args.transformer_pipeline_model_parallel_size) // \
args.num_layers_per_virtual_pipeline_stage
# >>>
# from lutil import pax
# pax({
# "num_layers" : args.num_layers,
# "pipeline size" : args.pipeline_model_parallel_size,
# "transformer size" : transformer_pipeline_size,
# "num virt layers" : args.num_layers_per_virtual_pipeline_stage,
# "virtual size" : args.virtual_pipeline_model_parallel_size,
# })
# <<<
else:
args.virtual_pipeline_model_parallel_size = None
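
For concreteness, here is a minimal sketch (values assumed, not part of this commit) of the arithmetic above; the derivation of the transformer pipeline size from the standalone embedding flag mirrors the logic used elsewhere in this change:

num_layers = 24
pipeline_model_parallel_size = 4
standalone_embed_stage = True
num_layers_per_virtual_pipeline_stage = 2

# With a standalone embedding stage, one pipeline rank holds only the input
# embedding, so transformer layers span the remaining ranks.
transformer_pipeline_model_parallel_size = (
    pipeline_model_parallel_size - 1
    if standalone_embed_stage
    else pipeline_model_parallel_size)                      # 3

assert num_layers % num_layers_per_virtual_pipeline_stage == 0
virtual_pipeline_model_parallel_size = (
    (num_layers // transformer_pipeline_model_parallel_size)
    // num_layers_per_virtual_pipeline_stage)               # (24 // 3) // 2 == 4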
......@@ -707,7 +692,8 @@ def _add_distributed_args(parser):
group.add_argument('--standalone-embed-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers.')
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
return parser
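
A hedged sketch of what the flag implies for stage contents; the helper below is hypothetical and only illustrates the layout described in the help text and in the mpu comments further down:

def describe_stage(pipeline_rank, pipeline_world_size, num_layers,
                   standalone_embed_stage):
    """Hypothetical helper: which layers a pipeline rank would host."""
    if standalone_embed_stage:
        if pipeline_rank == 0:
            return 'input embedding only, zero transformer layers'
        per_rank = num_layers // (pipeline_world_size - 1)
        return '%d transformer layers' % per_rank
    per_rank = num_layers // pipeline_world_size
    prefix = 'input embedding + ' if pipeline_rank == 0 else ''
    return prefix + '%d transformer layers' % per_rank

for r in range(4):
    print(r, describe_stage(r, 4, 24, standalone_embed_stage=True))
# 0 input embedding only, zero transformer layers
# 1 8 transformer layers
# 2 8 transformer layers
# 3 8 transformer layers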
......
......@@ -542,12 +542,20 @@ class ParallelTransformerLayer(MegatronModule):
return output
# >>>
class NoopTransformerLayer(MegatronModule):
"""A single 'no-op' transformer layer.
The sole purpose of this layer is for when args.standalone_embed_stage
== True. ?????
The sole purpose of this layer is for when a standalone embedding layer
is used (i.e., args.standalone_embed_stage == True). In this case,
zero transformer layers are assigned when pipeline rank == 0. Additionally,
when virtual pipeline rank >= 1, zero total model parameters are created
(virtual rank 0 contains the input embedding). This results in the model's
input and output tensors being the same, which causes an error when
performing certain memory optimizations on the output tensor (e.g.,
deallocating it). Thus, this layer disconnects the input from the output
via a clone. Since ranks containing a no-op layer are generally under-
utilized (both compute and memory), there's no worry of any performance
degradation.
"""
def __init__(self, layer_number):
......@@ -558,7 +566,6 @@ class NoopTransformerLayer(MegatronModule):
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
return hidden_states.clone()
# <<<
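
# Illustrative sketch (standalone example, not part of this diff) of why the
# no-op layer returns hidden_states.clone(): returning the input unchanged
# would make the schedule's output tensor the *same object* as its input, so
# deallocating the output would also free the input.
import torch

x = torch.randn(8, 4, requires_grad=True)
same = x            # output is the input itself
cloned = x.clone()  # separate storage, still connected in the autograd graph

assert same is x
assert cloned is not x and cloned.requires_grad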
class ParallelTransformer(MegatronModule):
......@@ -583,19 +590,8 @@ class ParallelTransformer(MegatronModule):
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
# Number of layers.
# >>>
# raise Exception("rank %d." % torch.distributed.get_rank())
# <<<
self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
# >>>
# if not self.pre_process and self.num_layers == 0:
# raise Exception(">>>> t %d, p %d, v %d. <<<<" % (
# mpu.get_tensor_model_parallel_rank(),
# mpu.get_pipeline_model_parallel_rank(),
# mpu.get_virtual_pipeline_model_parallel_rank(),
# ))
# <<<
# Transformer layers.
def build_layer(layer_number):
......@@ -637,28 +633,20 @@ class ParallelTransformer(MegatronModule):
else:
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
# >>>
if self.num_layers == 0:
# when args.standalone_embed_stage == True, virtual pipeline ranks
# When a standalone embedding stage is used (e.g.,
# args.standalone_embed_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This will cause a couple optimization techniques to fail:
#
# 1. distributed checkpointing (we
# 2. pipeline output tensor deallocation (would fail because the
# output tensor is the same object as the input tensor, and
# thus we also deallocate the input tensor, which causes
# autograd.backward to fail)
#
# to remedy this, we assign a 'no-op' layer on these ranks, which
# will pass the data flow through the checkpoint function, and in
# turn also results in the schedule's input and output tensors
# being separate objects.
# them. This results in the model's input and output tensors being
# the same, which will cause failure for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
# <<<
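
# Worked example (assumed sizes, not in this commit) of the non-interleaved
# offset above: each pipeline rank builds layers offset+1 .. offset+num_layers.
total_layers, pipeline_world_size = 24, 4
layers_per_rank = total_layers // pipeline_world_size          # 6
for rank in range(pipeline_world_size):
    offset = rank * layers_per_rank
    print(rank, [i + 1 + offset for i in range(layers_per_rank)])
# 0 [1, 2, 3, 4, 5, 6]
# 1 [7, 8, 9, 10, 11, 12]
# 2 [13, 14, 15, 16, 17, 18]
# 3 [19, 20, 21, 22, 23, 24]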
if self.post_process:
# Final layer norm before output.
......@@ -745,18 +733,6 @@ class ParallelTransformer(MegatronModule):
# See set_input_tensor()
hidden_states = self.input_tensor
# >>>
# if not self.pre_process and self.num_layers == 0:
# # raise Exception("tp %d, pp %d, vp %d ... hidden states %s, input tensor %s." % (
# # mpu.get_tensor_model_parallel_rank(),
# # mpu.get_pipeline_model_parallel_rank(),
# # mpu.get_virtual_pipeline_model_parallel_rank(),
# # "--" if hidden_states is None else str(hidden_states.shape),
# # "--" if self.input_tensor is None else str(self.input_tensor.shape),
# # ))
# hidden_states = hidden_states.clone()
# <<<
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
......@@ -804,26 +780,6 @@ class ParallelTransformer(MegatronModule):
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
# >>>
# if True or output._base is not None:
# # from lutil import pax, tp
# # pax({
# # "hidden_states" : tp(hidden_states),
# # "output" : tp(output),
# # })
# # raise Exception(">>> rank %d, view %d, hid '%s', out '%s'. <<<" %(
# # torch.distributed.get_rank(),
# # output._base is not None,
# # str(hidden_states.shape),
# # str(output.shape),
# # ))
# args = get_args()
# raise Exception(">>> rank %d, hid %d, view %d. <<<" %(
# torch.distributed.get_rank(),
# args.hidden_size,
# output._base is not None,
# ))
# <<<
else:
output = hidden_states
......
......@@ -269,9 +269,6 @@ def set_tensor_model_parallel_world_size(world_size):
def set_pipeline_model_parallel_world_size(world_size):
# >>>
raise Exception("hi.")
# <<<
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
......@@ -290,9 +287,6 @@ def get_pipeline_model_parallel_world_size():
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
# >>>
# raise Exception("hi.")
# <<<
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
......@@ -328,49 +322,34 @@ def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
# >>>
# raise Exception("fix for t5.")
# <<<
assert args.pipeline_model_parallel_split_rank is not None
# >>>
# num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
# +++
# When a standalone embedding stage is used, a rank is taken from
# the encoder's ranks, to be used for the encoder's embedding
# layer. This way, the rank referenced by the 'split rank' remains
# the same whether or not a standalone embedding stage is used.
num_ranks_in_encoder = (
args.pipeline_model_parallel_split_rank - 1
if args.standalone_embed_stage else
args.pipeline_model_parallel_split_rank
)
# <<<
# >>>
# num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
# +++
num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
# <<<
# >>>
# raise Exception(">>>> standalone %d, encoder %d, decoder %d. <<<<" % (
# args.standalone_embed_stage,
# num_ranks_in_encoder,
# num_ranks_in_decoder,
# ))
# <<<
assert args.num_layers % num_ranks_in_encoder == 0, \
'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
assert args.num_layers % num_ranks_in_decoder == 0, \
'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
if is_pipeline_stage_before_split(): # args):
if is_pipeline_stage_before_split():
num_layers = args.num_layers // num_ranks_in_encoder
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
# >>>
# transformer_pipeline_size = (
# get_pipeline_model_parallel_world_size() - 1
# if args.standalone_embed_stage else
# get_pipeline_model_parallel_world_size()
# )
# <<<
assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'num_layers must be divisible by transformer_pipeline_model_parallel_size'
# When a standalone embedding stage is used, all transformer layers
# are divided among pipeline rank >= 1, while on pipeline rank 0,
# ranks either contain the input embedding layer (virtual pp rank 0),
# or no layers at all (virtual pp rank >= 1).
num_layers = (
0
if args.standalone_embed_stage
......@@ -379,17 +358,6 @@ def get_num_layers(args, is_encoder_and_decoder_model):
)
else:
num_layers = args.num_layers
# >>>
# from lutil import pax
# pax(7, {
# "rank" : torch.distributed.get_rank(),
# "pipeline rank" : "%d / %d" % (
# get_pipeline_model_parallel_rank(),
# get_pipeline_model_parallel_world_size(),
# ),
# "num_layers" : num_layers,
# })
# <<<
return num_layers
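
As a worked example (all values assumed), consider an encoder-and-decoder model on a 4-stage pipeline with a standalone embedding stage, split rank 2, and 12 layers each for the encoder and decoder:

pipeline_world_size = 4
standalone_embed_stage = True
pipeline_model_parallel_split_rank = 2
num_layers = 12                                    # per encoder / per decoder

transformer_pipeline_size = (
    pipeline_world_size - 1 if standalone_embed_stage else pipeline_world_size)  # 3

# One rank is reserved for the encoder embedding, so the encoder keeps
# split_rank - 1 transformer ranks and the decoder keeps the remainder.
num_ranks_in_encoder = pipeline_model_parallel_split_rank - 1                    # 1
num_ranks_in_decoder = transformer_pipeline_size - num_ranks_in_encoder          # 2

assert num_layers % num_ranks_in_encoder == 0
assert num_layers % num_ranks_in_decoder == 0
print(num_layers // num_ranks_in_encoder)   # 12 layers on the encoder rank
print(num_layers // num_ranks_in_decoder)   # 6 layers on each decoder rank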
......@@ -438,9 +406,6 @@ def is_rank_in_position_embedding_group():
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
# >>>
# def is_pipeline_stage_before_split(args, rank=None):
# <<<
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
......@@ -448,11 +413,6 @@ def is_pipeline_stage_before_split(rank=None):
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
# >>>
# if args.standalone_embed_stage:
# rank += 1
assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
# <<<
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
......@@ -461,9 +421,6 @@ def is_pipeline_stage_before_split(rank=None):
return False
# >>>
# def is_pipeline_stage_after_split(args, rank=None):
# <<<
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
......@@ -471,11 +428,6 @@ def is_pipeline_stage_after_split(rank=None):
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
# >>>
# if args.standalone_embed_stage:
# rank += 1
assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
# <<<
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
......
......@@ -136,35 +136,22 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# >>>
def make_viewless_tensor(t):
return mpu.make_viewless_tensor(t, requires_grad=True, keep_graph=False)
# <<<
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
# >>>
# tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
# requires_grad = True,
# keep_graph = False)
# +++
tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
# <<<
tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
requires_grad = True,
keep_graph = False)
if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_()
# >>>
# tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
# requires_grad = True,
# keep_graph = False)
# +++
tensor_recv_next = make_viewless_tensor(tensor_recv_next)
# <<<
tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
requires_grad = True,
keep_graph = False)
return tensor_recv_prev, tensor_recv_next
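
The reason the gathered receive buffers are converted to "viewless" tensors can be shown with a small standalone sketch (assumptions noted in the comments; this is not the Megatron implementation):

import torch

flat = torch.empty(2 * 3 * 4)     # stand-in for the gathered 1-D receive buffer
shaped = flat.view(2, 3, 4)       # .view() shares storage: ._base points at flat
assert shaped._base is flat

viewless = shaped.clone()         # one way to drop the back-reference, shown for
                                  # illustration; the real helper may avoid the copy
assert viewless._base is None     # now the schedule can free this tensor without
                                  # also freeing the receive buffer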
......
......@@ -34,25 +34,6 @@ def get_forward_backward_func():
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
# >>>
# from lutil import pax
# pax({
# "num microbatches" : get_num_microbatches(),
# "pipeline size" : args.pipeline_model_parallel_size,
# })
# <<<
# >>>
# assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
# 'number of microbatches is not divisible by pipeline-parallel ' \
# 'size when using interleaved schedule'
# assert get_num_microbatches() % \
# args.transformer_pipeline_model_parallel_size == 0, \
# 'number of microbatches (%d) is not divisible by transformer-' \
# 'pipeline-model-parallel-size (%d) when using interleaved ' \
# 'schedule' % (
# get_num_microbatches(),
# args.transformer_pipeline_model_parallel_size,
# )
assert get_num_microbatches() % \
args.pipeline_model_parallel_size == 0, \
'number of microbatches (%d) is not divisible by pipeline-' \
......@@ -60,7 +41,6 @@ def get_forward_backward_func():
get_num_microbatches(),
args.pipeline_model_parallel_size,
)
# <<<
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
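
A minimal sketch (values assumed) of the divisibility requirement asserted above for the interleaved schedule:

def check_interleaved(num_microbatches, pipeline_model_parallel_size):
    assert num_microbatches % pipeline_model_parallel_size == 0, (
        'number of microbatches (%d) is not divisible by pipeline-parallel '
        'size (%d) when using interleaved schedule'
        % (num_microbatches, pipeline_model_parallel_size))

check_interleaved(num_microbatches=8, pipeline_model_parallel_size=4)   # passes
# check_interleaved(6, 4) would raise AssertionError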
......@@ -143,9 +123,6 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
# >>>
mpu.assert_viewless_tensor(output_tensor)
# <<<
if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
......@@ -153,10 +130,6 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
# >>>
mpu.assert_viewless_tensor(output_tensor)
# <<<
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
......@@ -341,15 +314,6 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
# >>>
# if id(input_tensor) == id(output_tensor):
# raise Exception("tp %d, pp %d, vp %d." % (
# mpu.get_tensor_model_parallel_rank(),
# mpu.get_pipeline_model_parallel_rank(),
# mpu.get_virtual_pipeline_model_parallel_rank(),
# ))
# <<<
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
......
......@@ -136,14 +136,6 @@ def pretrain(train_valid_test_dataset_provider,
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
# >>>
# from lutil import pax
# pax({
# "model / len" : len(model),
# # "do_train": args.do_train,
# })
# <<<
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
......@@ -207,14 +199,6 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
args = get_args()
args.model_type = model_type
# >>>
# from lutil import pax
# pax({
# "pipeline world size" : mpu.get_pipeline_model_parallel_world_size(),
# "virtual size" : args.virtual_pipeline_model_parallel_size,
# })
# <<<
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
......@@ -232,13 +216,6 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
)
this_model.model_type = model_type
model.append(this_model)
# >>>
# from lutil import pax
# pax({
# "virtual size" : args.virtual_pipeline_model_parallel_size,
# "model" : model,
# })
# <<<
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
......@@ -254,10 +231,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
pre_process = rank == 0 or rank == split_rank
post_process = (rank == (split_rank - 1)) or (
rank == (world_size - 1))
# >>>
add_encoder = mpu.is_pipeline_stage_before_split() # args)
add_decoder = mpu.is_pipeline_stage_after_split() # args)
# <<<
add_encoder = mpu.is_pipeline_stage_before_split()
add_decoder = mpu.is_pipeline_stage_after_split()
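
# Worked example (world_size and split_rank assumed; the add_encoder/add_decoder
# conditions below mirror is_pipeline_stage_before/after_split, an assumption
# here) of the flags computed above for an encoder+decoder pipeline.
world_size, split_rank = 4, 2
for rank in range(world_size):
    pre_process = rank == 0 or rank == split_rank
    post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
    add_encoder = rank < split_rank
    add_decoder = rank >= split_rank
    print(rank, pre_process, post_process, add_encoder, add_decoder)
# 0 True False True False    <- first encoder stage (encoder embedding)
# 1 False True True False    <- last encoder stage
# 2 True False False True    <- first decoder stage (decoder embedding)
# 3 False True False True    <- last decoder stage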
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
......@@ -370,21 +345,6 @@ def setup_model_and_optimizer(model_provider_func, model_type):
model = get_model(model_provider_func, model_type)
# >>>
# if mpu.get_tensor_model_parallel_rank() == 0:
# from lutil import pax
# pax({
# # "model" : model,
# "model" : [
# sum(t.nelement() for t in m.parameters())
# for m in model
# ],
# })
# else:
# torch.distributed.barrier()
# exit(0)
# <<<
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model)
......@@ -952,11 +912,6 @@ def build_train_valid_test_data_iterators(
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
# >>>
# from lutil import pax
# pax({"hi": "there"})
# <<<
# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic']
......