Commit b93bef00 authored by Lawrence McAfee

comments, cleanup.

parent bea16fa3
@@ -141,24 +141,9 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
             'number of layers is not divisible by number of layers per virtual ' \
             'pipeline stage'
-        # >>>
-        # args.virtual_pipeline_model_parallel_size = \
-        #     (args.num_layers // args.pipeline_model_parallel_size) // \
-        #     args.num_layers_per_virtual_pipeline_stage
-        # <<<
         args.virtual_pipeline_model_parallel_size = \
             (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
             args.num_layers_per_virtual_pipeline_stage
-        # >>>
-        # from lutil import pax
-        # pax({
-        #     "num_layers" : args.num_layers,
-        #     "pipeline size" : args.pipeline_model_parallel_size,
-        #     "transformer size" : transformer_pipeline_size,
-        #     "num virt layers" : args.num_layers_per_virtual_pipeline_stage,
-        #     "virtual size" : args.virtual_pipeline_model_parallel_size,
-        # })
-        # <<<
     else:
         args.virtual_pipeline_model_parallel_size = None
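As a minimal sketch of the corrected arithmetic (all values below are hypothetical), the virtual pipeline size is now derived from the transformer pipeline size, i.e. the pipeline size minus the stage reserved for the standalone embedding:

```python
# Hypothetical configuration: 24 layers, pipeline size 5 with one stage
# reserved for the standalone embedding, and 3 layers per virtual stage.
num_layers = 24
pipeline_model_parallel_size = 5
transformer_pipeline_model_parallel_size = pipeline_model_parallel_size - 1
num_layers_per_virtual_pipeline_stage = 3

virtual_pipeline_model_parallel_size = (
    (num_layers // transformer_pipeline_model_parallel_size)
    // num_layers_per_virtual_pipeline_stage
)
assert virtual_pipeline_model_parallel_size == 2  # 24 // 4 = 6; 6 // 3 = 2
```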
@@ -707,7 +692,8 @@ def _add_distributed_args(parser):
     group.add_argument('--standalone-embed-stage', action='store_true',
                        default=False, help='If set, *input* embedding layer '
                        'is placed on its own pipeline stage, without any '
-                       'transformer layers.')
+                       'transformer layers. (For T5, this flag currently only '
+                       'affects the encoder embedding.)')
     return parser
@@ -542,12 +542,20 @@ class ParallelTransformerLayer(MegatronModule):
         return output

-# >>>
 class NoopTransformerLayer(MegatronModule):
     """A single 'no-op' transformer layer.

-    The sole purpose of this layer is for when args.standalone_embed_stage
-    == True. ?????
+    The sole purpose of this layer is for when a standalone embedding layer
+    is used (i.e., args.standalone_embed_stage == True). In this case,
+    zero transformer layers are assigned when pipeline rank == 0. Additionally,
+    when virtual pipeline rank >= 1, zero total model parameters are created
+    (virtual rank 0 contains the input embedding). This results in the model's
+    input and output tensors being the same, which causes an error when
+    performing certain memory optimizations on the output tensor (e.g.,
+    deallocating it). Thus, this layer disconnects the input from the output
+    via a clone. Since ranks containing a no-op layer are generally under-
+    utilized (both compute and memory), there is no concern about performance
+    degradation.
     """

     def __init__(self, layer_number):
@@ -558,7 +566,6 @@ class NoopTransformerLayer(MegatronModule):
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
         return hidden_states.clone()
-# <<<

 class ParallelTransformer(MegatronModule):
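The clone in `NoopTransformerLayer.forward` is what keeps the schedule's input and output tensors distinct. A small, self-contained illustration with hypothetical tensors (not part of the commit):

```python
import torch

# If a stage simply returned its input, the schedule's input and output would
# be the same object, so deallocating the output buffer would also destroy the
# input needed for backward. A clone gives the output its own storage while
# staying connected to the input in the autograd graph.
x = torch.randn(4, 8, requires_grad=True)

identity_out = x         # same storage as the input
cloned_out = x.clone()   # what NoopTransformerLayer returns: separate storage

assert identity_out.data_ptr() == x.data_ptr()
assert cloned_out.data_ptr() != x.data_ptr()

cloned_out.sum().backward()   # gradients still flow back through the clone
assert x.grad is not None
```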
@@ -583,19 +590,8 @@ class ParallelTransformer(MegatronModule):
         self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

         # Number of layers.
-        # >>>
-        # raise Exception("rank %d." % torch.distributed.get_rank())
-        # <<<
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
-        # >>>
-        # if not self.pre_process and self.num_layers == 0:
-        #     raise Exception(">>>> t %d, p %d, v %d. <<<<" % (
-        #         mpu.get_tensor_model_parallel_rank(),
-        #         mpu.get_pipeline_model_parallel_rank(),
-        #         mpu.get_virtual_pipeline_model_parallel_rank(),
-        #     ))
-        # <<<

         # Transformer layers.
         def build_layer(layer_number):
@@ -637,28 +633,20 @@ class ParallelTransformer(MegatronModule):
         else:
             offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

-        # >>>
         if self.num_layers == 0:
-            # when args.standalone_embed_stage == True, virtual pipeline ranks
+            # When a standalone embedding stage is used (i.e.,
+            # args.standalone_embed_stage == True), virtual pipeline ranks
             # on pipeline rank 0 will have zero transformer layers assigned to
-            # them. This will cause a couple optimization techniques to fail:
-            #
-            #   1. distributed checkpointing (we
-            #   2. pipeline output tensor deallocation (would fail because the
-            #      output tensor is the same object as the input tensor, and
-            #      thus we also deallocate the input tensor, which causes
-            #      autograd.backward to fail)
-            #
-            # to remedy this, we assign a 'no-op' layer on these ranks, which
-            # will pass the data flow through the checkpoint function, and in
-            # turn also results in the schedule's input and output tensors
-            # being separate objects.
+            # them. This results in the model's input and output tensors being
+            # the same, which causes failures in certain output-tensor
+            # optimizations (e.g., pipeline output deallocation). To remedy
+            # this, we assign a 'no-op' layer on these ranks, which will
+            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
         else:
             self.layers = torch.nn.ModuleList(
                 [build_layer(i + 1 + offset) for i in range(self.num_layers)])
-        # <<<

         if self.post_process:
             # Final layer norm before output.
@@ -745,18 +733,6 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor

-        # >>>
-        # if not self.pre_process and self.num_layers == 0:
-        #     # raise Exception("tp %d, pp %d, vp %d ... hidden states %s, input tensor %s." % (
-        #     #     mpu.get_tensor_model_parallel_rank(),
-        #     #     mpu.get_pipeline_model_parallel_rank(),
-        #     #     mpu.get_virtual_pipeline_model_parallel_rank(),
-        #     #     "--" if hidden_states is None else str(hidden_states.shape),
-        #     #     "--" if self.input_tensor is None else str(self.input_tensor.shape),
-        #     #     ))
-        #     hidden_states = hidden_states.clone()
-        # <<<

         # Viewless tensor.
         # - We only need to create a viewless tensor in the case of micro batch
         #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
@@ -804,26 +780,6 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
-            # >>>
-            # if True or output._base is not None:
-            #     # from lutil import pax, tp
-            #     # pax({
-            #     #     "hidden_states" : tp(hidden_states),
-            #     #     "output" : tp(output),
-            #     # })
-            #     # raise Exception(">>> rank %d, view %d, hid '%s', out '%s'. <<<" %(
-            #     #     torch.distributed.get_rank(),
-            #     #     output._base is not None,
-            #     #     str(hidden_states.shape),
-            #     #     str(output.shape),
-            #     #     ))
-            #     args = get_args()
-            #     raise Exception(">>> rank %d, hid %d, view %d. <<<" %(
-            #         torch.distributed.get_rank(),
-            #         args.hidden_size,
-            #         output._base is not None,
-            #         ))
-            # <<<
         else:
             output = hidden_states
@@ -269,9 +269,6 @@ def set_tensor_model_parallel_world_size(world_size):
 def set_pipeline_model_parallel_world_size(world_size):
-    # >>>
-    raise Exception("hi.")
-    # <<<
     """Set the pipeline model parallel size"""
     global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
@@ -290,9 +287,6 @@ def get_pipeline_model_parallel_world_size():
     global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
         return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-    # >>>
-    # raise Exception("hi.")
-    # <<<
     return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
@@ -328,49 +322,34 @@ def get_num_layers(args, is_encoder_and_decoder_model):
     """Compute the number of transformer layers resident on the current rank."""
     if get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
-            # >>>
-            # raise Exception("fix for t5.")
-            # <<<
             assert args.pipeline_model_parallel_split_rank is not None
-            # >>>
-            # num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
-            # +++
+            # When a standalone embedding stage is used, a rank is taken from
+            # the encoder's ranks, to be used for the encoder's embedding
+            # layer. This way, the rank referenced by the 'split rank' remains
+            # the same whether or not a standalone embedding stage is used.
             num_ranks_in_encoder = (
                 args.pipeline_model_parallel_split_rank - 1
                 if args.standalone_embed_stage else
                 args.pipeline_model_parallel_split_rank
             )
-            # <<<
-            # >>>
-            # num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
-            # +++
             num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
-            # <<<
-            # >>>
-            # raise Exception(">>>> standalone %d, encoder %d, decoder %d. <<<<" % (
-            #     args.standalone_embed_stage,
-            #     num_ranks_in_encoder,
-            #     num_ranks_in_decoder,
-            #     ))
-            # <<<
             assert args.num_layers % num_ranks_in_encoder == 0, \
                 'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
             assert args.num_layers % num_ranks_in_decoder == 0, \
                 'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
-            if is_pipeline_stage_before_split(): # args):
+            if is_pipeline_stage_before_split():
                 num_layers = args.num_layers // num_ranks_in_encoder
             else:
                 num_layers = args.num_layers // num_ranks_in_decoder
         else:
-            # >>>
-            # transformer_pipeline_size = (
-            #     get_pipeline_model_parallel_world_size() - 1
-            #     if args.standalone_embed_stage else
-            #     get_pipeline_model_parallel_world_size()
-            # )
-            # <<<
             assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                 'num_layers must be divisible by transformer_pipeline_model_parallel_size'
+            # When a standalone embedding stage is used, all transformer layers
+            # are divided among pipeline rank >= 1, while on pipeline rank 0,
+            # ranks either contain the input embedding layer (virtual pp rank 0),
+            # or no layers at all (virtual pp rank >= 1).
             num_layers = (
                 0
                 if args.standalone_embed_stage
@@ -379,17 +358,6 @@ def get_num_layers(args, is_encoder_and_decoder_model):
             )
     else:
         num_layers = args.num_layers
-    # >>>
-    # from lutil import pax
-    # pax(7, {
-    #     "rank" : torch.distributed.get_rank(),
-    #     "pipeline rank" : "%d / %d" % (
-    #         get_pipeline_model_parallel_rank(),
-    #         get_pipeline_model_parallel_world_size(),
-    #     ),
-    #     "num_layers" : num_layers,
-    #     })
-    # <<<
     return num_layers
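A worked example of the encoder/decoder rank arithmetic above, with all values hypothetical:

```python
# Hypothetical T5-style setup: 12 layers per stack, pipeline split rank 3,
# transformer pipeline size 6 (i.e. a total pipeline size of 7, one of which
# is the standalone embedding stage).
num_layers = 12
pipeline_model_parallel_split_rank = 3
transformer_pipeline_model_parallel_size = 6
standalone_embed_stage = True

num_ranks_in_encoder = (
    pipeline_model_parallel_split_rank - 1
    if standalone_embed_stage else
    pipeline_model_parallel_split_rank
)                                                              # -> 2
num_ranks_in_decoder = (
    transformer_pipeline_model_parallel_size - num_ranks_in_encoder
)                                                              # -> 4

assert num_layers % num_ranks_in_encoder == 0   # 6 layers per encoder rank
assert num_layers % num_ranks_in_decoder == 0   # 3 layers per decoder rank
```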
@@ -438,9 +406,6 @@ def is_rank_in_position_embedding_group():
     return rank in _POSITION_EMBEDDING_GLOBAL_RANKS

-# >>>
-# def is_pipeline_stage_before_split(args, rank=None):
-# <<<
 def is_pipeline_stage_before_split(rank=None):
     """Return True if pipeline stage executes encoder block for a model
     with both encoder and decoder."""
@@ -448,11 +413,6 @@ def is_pipeline_stage_before_split(rank=None):
         return True
     if rank is None:
         rank = get_pipeline_model_parallel_rank()
-    # >>>
-    # if args.standalone_embed_stage:
-    #     rank += 1
-    assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
-    # <<<
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
         return True
@@ -461,9 +421,6 @@ def is_pipeline_stage_before_split(rank=None):
     return False

-# >>>
-# def is_pipeline_stage_after_split(args, rank=None):
-# <<<
 def is_pipeline_stage_after_split(rank=None):
     """Return True if pipeline stage executes decoder block for a model
     with both encoder and decoder."""
@@ -471,11 +428,6 @@ def is_pipeline_stage_after_split(rank=None):
         return True
     if rank is None:
         rank = get_pipeline_model_parallel_rank()
-    # >>>
-    # if args.standalone_embed_stage:
-    #     rank += 1
-    assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
-    # <<<
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
         return True
@@ -136,35 +136,22 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         # To protect against race condition when using batch_isend_irecv().
         torch.cuda.synchronize()

-    # >>>
-    def make_viewless_tensor(t):
-        return mpu.make_viewless_tensor(t, requires_grad=True, keep_graph=False)
-    # <<<
     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
             args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            # >>>
-            # tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
-            #                                             requires_grad = True,
-            #                                             keep_graph = False)
-            # +++
-            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
-            # <<<
+            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
+                                                        requires_grad = True,
+                                                        keep_graph = False)
         if recv_next:
             tensor_recv_next = mpu.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            # >>>
-            # tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
-            #                                             requires_grad = True,
-            #                                             keep_graph = False)
-            # +++
-            tensor_recv_next = make_viewless_tensor(tensor_recv_next)
-            # <<<
+            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
                                                        requires_grad = True,
                                                        keep_graph = False)

     return tensor_recv_prev, tensor_recv_next
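For background, here is a rough, hypothetical plain-PyTorch illustration of what a "viewless" receive buffer means in this context; it is not the mpu implementation, only the idea behind the `mpu.make_viewless_tensor(...)` call above:

```python
import torch

# Stand-in for the gathered 1-D receive buffer.
flat = torch.arange(12.0)

# '.view()' keeps a reference to the flat buffer: the result is a view whose
# '._base' attribute points back at 'flat'.
viewed = flat.view(3, 4)
assert viewed._base is flat

# A "viewless" tensor owns its storage, so later bookkeeping on the flat
# buffer cannot corrupt the received activation.
viewless = torch.empty_like(viewed).copy_(viewed).requires_grad_()
assert viewless._base is None
```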
@@ -34,25 +34,6 @@ def get_forward_backward_func():
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
-            # >>>
-            # from lutil import pax
-            # pax({
-            #     "num microbatches" : get_num_microbatches(),
-            #     "pipeline size" : args.pipeline_model_parallel_size,
-            # })
-            # <<<
-            # >>>
-            # assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
-            #     'number of microbatches is not divisible by pipeline-parallel ' \
-            #     'size when using interleaved schedule'
-            # assert get_num_microbatches() % \
-            #     args.transformer_pipeline_model_parallel_size == 0, \
-            #     'number of microbatches (%d) is not divisible by transformer-' \
-            #     'pipeline-model-parallel-size (%d) when using interleaved ' \
-            #     'schedule' % (
-            #         get_num_microbatches(),
-            #         args.transformer_pipeline_model_parallel_size,
-            #     )
             assert get_num_microbatches() % \
                 args.pipeline_model_parallel_size == 0, \
                 'number of microbatches (%d) is not divisible by pipeline-' \
@@ -60,7 +41,6 @@ def get_forward_backward_func():
                     get_num_microbatches(),
                     args.pipeline_model_parallel_size,
                 )
-            # <<<
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
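A tiny, hypothetical sanity check of the divisibility requirement retained above:

```python
# Hypothetical values: the interleaved schedule requires the number of
# microbatches to be a multiple of the pipeline-model-parallel size.
num_microbatches = 12
pipeline_model_parallel_size = 4
assert num_microbatches % pipeline_model_parallel_size == 0  # 12 % 4 == 0
```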
@@ -143,9 +123,6 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
-    # >>>
-    mpu.assert_viewless_tensor(output_tensor)
-    # <<<
     if mpu.is_pipeline_last_stage():
         output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor
@@ -153,10 +130,6 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
         losses_reduced.append(loss_reduced)
     timers('forward-compute').stop()

-    # >>>
-    mpu.assert_viewless_tensor(output_tensor)
-    # <<<
     # If T5 model (or other model with encoder and decoder)
     # and in decoder stack, then send encoder_hidden_state
     # downstream as well.
@@ -341,15 +314,6 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                      input_tensor, losses_reduced)
         output_tensors[model_chunk_id].append(output_tensor)

-        # >>>
-        # if id(input_tensor) == id(output_tensor):
-        #     raise Exception("tp %d, pp %d, vp %d." % (
-        #         mpu.get_tensor_model_parallel_rank(),
-        #         mpu.get_pipeline_model_parallel_rank(),
-        #         mpu.get_virtual_pipeline_model_parallel_rank(),
-        #     ))
-        # <<<
         # if forward-only, no need to save tensors for a backward pass
         if forward_only:
             input_tensors[model_chunk_id].pop()
@@ -136,14 +136,6 @@ def pretrain(train_valid_test_dataset_provider,
     timers('train/valid/test-data-iterators-setup').stop()
     print_datetime('after dataloaders are built')

-    # >>>
-    # from lutil import pax
-    # pax({
-    #     "model / len" : len(model),
-    #     # "do_train": args.do_train,
-    # })
-    # <<<
     # Print setup timing.
     print_rank_0('done with setup ...')
     timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
@@ -207,14 +199,6 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     args = get_args()
     args.model_type = model_type

-    # >>>
-    # from lutil import pax
-    # pax({
-    #     "pipeline world size" : mpu.get_pipeline_model_parallel_world_size(),
-    #     "virtual size" : args.virtual_pipeline_model_parallel_size,
-    # })
-    # <<<
     # Build model.
     if mpu.get_pipeline_model_parallel_world_size() > 1 and \
        args.virtual_pipeline_model_parallel_size is not None:
@@ -232,13 +216,6 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
             )
             this_model.model_type = model_type
             model.append(this_model)
-        # >>>
-        # from lutil import pax
-        # pax({
-        #     "virtual size" : args.virtual_pipeline_model_parallel_size,
-        #     "model" : model,
-        # })
-        # <<<
     else:
         pre_process = mpu.is_pipeline_first_stage()
         post_process = mpu.is_pipeline_last_stage()
@@ -254,10 +231,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
             pre_process = rank == 0 or rank == split_rank
             post_process = (rank == (split_rank - 1)) or (
                 rank == (world_size - 1))
-            # >>>
-            add_encoder = mpu.is_pipeline_stage_before_split() # args)
-            add_decoder = mpu.is_pipeline_stage_after_split() # args)
-            # <<<
+            add_encoder = mpu.is_pipeline_stage_before_split()
+            add_decoder = mpu.is_pipeline_stage_after_split()
         model = model_provider_func(
             pre_process=pre_process,
             post_process=post_process,
@@ -370,21 +345,6 @@ def setup_model_and_optimizer(model_provider_func, model_type):
     model = get_model(model_provider_func, model_type)

-    # >>>
-    # if mpu.get_tensor_model_parallel_rank() == 0:
-    #     from lutil import pax
-    #     pax({
-    #         # "model" : model,
-    #         "model" : [
-    #             sum(t.nelement() for t in m.parameters())
-    #             for m in model
-    #         ],
-    #     })
-    # else:
-    #     torch.distributed.barrier()
-    # exit(0)
-    # <<<
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
     optimizer = get_megatron_optimizer(unwrapped_model)
@@ -952,11 +912,6 @@ def build_train_valid_test_data_iterators(
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()

-    # >>>
-    # from lutil import pax
-    # pax({"hi": "there"})
-    # <<<
     # Build iterators.
     dl_type = args.dataloader_type
     assert dl_type in ['single', 'cyclic']