Commit c2b7d0b3 authored by Lawrence McAfee

fixed args.virtual_pipeline_model_parallel_size

parent 33dc8e9c
@@ -136,9 +136,29 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers is not divisible by number of layers per virtual ' \
'pipeline stage'
# >>>
# args.virtual_pipeline_model_parallel_size = \
# (args.num_layers // args.pipeline_model_parallel_size) // \
# args.num_layers_per_virtual_pipeline_stage
# <<<
transformer_pipeline_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embed_stage else
args.pipeline_model_parallel_size
)
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // transformer_pipeline_size) // \
args.num_layers_per_virtual_pipeline_stage
# >>>
# from lutil import pax
# pax({
# "num_layers" : args.num_layers,
# "pipeline size" : args.pipeline_model_parallel_size,
# "transformer size" : transformer_pipeline_size,
# "num virt layers" : args.num_layers_per_virtual_pipeline_stage,
# "virtual size" : args.virtual_pipeline_model_parallel_size,
# })
# <<<
else:
args.virtual_pipeline_model_parallel_size = None
......
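
Note on the parse_args hunk above: a small worked example of the corrected arithmetic, using hypothetical values (not taken from this diff). With 24 transformer layers, a pipeline size of 4, and a standalone embedding stage, the transformer layers span 3 stages:

    # Worked example (hypothetical values) of the virtual pipeline size math.
    num_layers = 24                          # total transformer layers
    pipeline_model_parallel_size = 4         # pipeline stages, incl. embedding stage
    standalone_embed_stage = True            # rank 0 holds only the embedding

    # Stages that actually hold transformer layers.
    transformer_pipeline_size = (
        pipeline_model_parallel_size - 1
        if standalone_embed_stage else
        pipeline_model_parallel_size
    )                                        # -> 3

    num_layers_per_virtual_pipeline_stage = 2
    virtual_pipeline_model_parallel_size = (
        (num_layers // transformer_pipeline_size)
        // num_layers_per_virtual_pipeline_stage
    )                                        # (24 // 3) // 2 -> 4
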
@@ -565,6 +565,9 @@ class ParallelTransformer(MegatronModule):
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
# Number of layers.
# >>>
# raise Exception("rank %d." % torch.distributed.get_rank())
# <<<
self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
......
@@ -356,16 +356,19 @@ def get_num_layers(args, is_encoder_and_decoder_model):
and get_pipeline_model_parallel_rank() == 0 else
args.num_layers // transformer_pipeline_size
)
# >>>
# from lutil import pax
# pax({
# "rank" : torch.distributed.get_rank(),
# "pipeline rank" : get_pipeline_model_parallel_rank(),
# "num_layers" : num_layers,
# })
# <<<
else:
num_layers = args.num_layers
# >>>
from lutil import pax
pax(0, {
"rank" : torch.distributed.get_rank(),
"pipeline rank" : "%d / %d" % (
get_pipeline_model_parallel_rank(),
get_pipeline_model_parallel_world_size(),
),
"num_layers" : num_layers,
})
# <<<
return num_layers
......
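
For context, a minimal sketch of the per-rank split that get_num_layers appears to implement here, ignoring the encoder-and-decoder case; the function below is illustrative, not Megatron's API:

    # Sketch (assumption): transformer layers owned by one pipeline rank when
    # rank 0 can be a standalone embedding-only stage.
    def num_layers_for_rank(num_layers, pipeline_size, pipeline_rank,
                            standalone_embed_stage):
        if pipeline_size > 1:
            transformer_pipeline_size = (
                pipeline_size - 1 if standalone_embed_stage else pipeline_size)
            if standalone_embed_stage and pipeline_rank == 0:
                return 0  # embedding-only stage holds no transformer layers
            return num_layers // transformer_pipeline_size
        return num_layers  # no pipeline parallelism: one rank holds all layers
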
@@ -136,6 +136,14 @@ def pretrain(train_valid_test_dataset_provider,
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
# >>>
from lutil import pax
pax({
"model / len" : len(model),
# "do_train": args.do_train,
})
# <<<
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
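
lutil.pax, used throughout this commit, looks like a personal print-and-halt debugging helper; it is called both as pax({...}) and pax(0, {...}). A minimal stand-in, purely as an assumption about its behavior:

    # Hypothetical stand-in for lutil.pax; not the real library. Dumps the
    # given values with the caller's rank, then halts so state can be read.
    import sys
    import torch

    def pax(*args):
        data = args[-1]  # the dict is always the last argument
        rank = (torch.distributed.get_rank()
                if torch.distributed.is_initialized() else 0)
        for key, value in data.items():
            print("[rank %d] %s = %s" % (rank, key, value), flush=True)
        sys.exit(0)
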
@@ -199,6 +207,14 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
args = get_args()
args.model_type = model_type
# >>>
# from lutil import pax
# pax({
# "pipeline world size" : mpu.get_pipeline_model_parallel_world_size(),
# "virtual size" : args.virtual_pipeline_model_parallel_size,
# })
# <<<
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
@@ -216,6 +232,13 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
)
this_model.model_type = model_type
model.append(this_model)
# >>>
from lutil import pax
pax({
"virtual size" : args.virtual_pipeline_model_parallel_size,
"model" : model,
})
# <<<
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
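
For context, a sketch reconstructing the elided loop around the hunk above: under interleaved (virtual) pipelining, get_model builds one module per virtual stage on each rank. Identifiers follow the diff; the wrapper function itself is illustrative:

    # Sketch of the virtual-pipeline branch of get_model (reconstruction).
    def build_virtual_pipeline_model(args, mpu, model_provider_func, model_type):
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            # Each virtual rank gets its own chunk of the transformer stack.
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            this_model = model_provider_func(
                pre_process=mpu.is_pipeline_first_stage(),
                post_process=mpu.is_pipeline_last_stage())
            this_model.model_type = model_type
            model.append(this_model)
        return model  # a list of modules, one per virtual stage
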
@@ -342,6 +365,11 @@ def setup_model_and_optimizer(model_provider_func, model_type):
model = get_model(model_provider_func, model_type)
# >>>
from lutil import pax
pax({"model": model})
# <<<
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model)
@@ -909,6 +937,10 @@ def build_train_valid_test_data_iterators(
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
# >>>
from lutil import pax
pax({"hi": "there"})
# <<<
# Build iterators.
dl_type = args.dataloader_type
......
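
For context, the do_train/do_valid/do_test flags in the last hunk are typically computed on rank 0 and broadcast so every rank agrees; a hedged sketch of that pattern (the helper name is illustrative):

    # Sketch (assumption, mirroring the usual Megatron pattern): broadcast the
    # split-availability flags from rank 0 to all ranks.
    import torch

    def broadcast_split_flags(do_train, do_valid, do_test):
        flags = torch.tensor([int(do_train), int(do_valid), int(do_test)],
                             dtype=torch.long, device='cuda')
        torch.distributed.broadcast(flags, src=0)  # everyone gets rank 0's view
        return flags[0].item(), flags[1].item(), flags[2].item()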