Commit a06af061 authored by Lawrence McAfee

added args.transformer_pipeline_model_parallel_size

parent c2b7d0b3
@@ -66,6 +66,11 @@ def parse_args(extra_args_provider=None, defaults={},
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
args.transformer_pipeline_model_parallel_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embed_stage else
args.pipeline_model_parallel_size
)
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
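
The new attribute reduces to a one-line rule: when a standalone embedding stage is used, one pipeline rank holds no transformer layers and is excluded from the transformer pipeline size. A minimal, self-contained sketch of the same arithmetic (the Namespace values below are hypothetical, not taken from the commit):

from argparse import Namespace

# Hypothetical configuration; in Megatron these fields come from parse_args().
args = Namespace(world_size=16,
                 tensor_model_parallel_size=2,
                 pipeline_model_parallel_size=8,
                 standalone_embed_stage=True)

# Pipeline size cannot exceed world_size // tensor_model_parallel_size.
args.pipeline_model_parallel_size = min(
    args.pipeline_model_parallel_size,
    args.world_size // args.tensor_model_parallel_size)

# With a standalone embedding stage, one pipeline rank carries no
# transformer layers, so it does not count toward the transformer pipeline.
args.transformer_pipeline_model_parallel_size = (
    args.pipeline_model_parallel_size - 1
    if args.standalone_embed_stage else
    args.pipeline_model_parallel_size)

print(args.transformer_pipeline_model_parallel_size)  # 7
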
@@ -141,11 +146,6 @@ def parse_args(extra_args_provider=None, defaults={},
# (args.num_layers // args.pipeline_model_parallel_size) // \
# args.num_layers_per_virtual_pipeline_stage
# <<<
transformer_pipeline_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embed_stage else
args.pipeline_model_parallel_size
)
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // transformer_pipeline_size) // \
args.num_layers_per_virtual_pipeline_stage
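
For the interleaved schedule, the number of model chunks per rank follows directly from the formula in the hunk above; a small sketch with made-up layer counts (all values here are assumptions for illustration):

# Hypothetical values mirroring the formula above.
num_layers = 24
standalone_embed_stage = True
pipeline_model_parallel_size = 5
num_layers_per_virtual_pipeline_stage = 2

transformer_pipeline_size = (
    pipeline_model_parallel_size - 1
    if standalone_embed_stage else
    pipeline_model_parallel_size)  # 4 ranks actually hold transformer layers

virtual_pipeline_model_parallel_size = (
    (num_layers // transformer_pipeline_size)
    // num_layers_per_virtual_pipeline_stage)

print(virtual_pipeline_model_parallel_size)  # 3 model chunks per pipeline rank
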
@@ -343,11 +343,13 @@ def get_num_layers(args, is_encoder_and_decoder_model):
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
transformer_pipeline_size = (
get_pipeline_model_parallel_world_size() - 1
if args.standalone_embed_stage else
get_pipeline_model_parallel_world_size()
)
# >>>
# transformer_pipeline_size = (
# get_pipeline_model_parallel_world_size() - 1
# if args.standalone_embed_stage else
# get_pipeline_model_parallel_world_size()
# )
# <<<
assert args.num_layers % transformer_pipeline_size == 0, \
'num_layers must be divisible by transformer_pipeline_size'
num_layers = (
@@ -359,15 +361,15 @@ def get_num_layers(args, is_encoder_and_decoder_model):
else:
num_layers = args.num_layers
# >>>
from lutil import pax
pax(0, {
"rank" : torch.distributed.get_rank(),
"pipeline rank" : "%d / %d" % (
get_pipeline_model_parallel_rank(),
get_pipeline_model_parallel_world_size(),
),
"num_layers" : num_layers,
})
# from lutil import pax
# pax(7, {
# "rank" : torch.distributed.get_rank(),
# "pipeline rank" : "%d / %d" % (
# get_pipeline_model_parallel_rank(),
# get_pipeline_model_parallel_world_size(),
# ),
# "num_layers" : num_layers,
# })
# <<<
return num_layers
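
A sketch of how the per-rank layer count could look with a standalone embedding stage; the rank-0-holds-only-the-embedding convention is an assumption for illustration, and the real get_num_layers() also covers the encoder-and-decoder case:

def layers_on_this_rank(num_layers, pipeline_world_size, pipeline_rank,
                        standalone_embed_stage):
    """Hypothetical helper; mirrors the divisibility rule asserted above."""
    transformer_pipeline_size = (
        pipeline_world_size - 1
        if standalone_embed_stage else
        pipeline_world_size)
    assert num_layers % transformer_pipeline_size == 0, \
        'num_layers must be divisible by transformer_pipeline_size'
    if standalone_embed_stage and pipeline_rank == 0:
        return 0  # embedding-only stage: no transformer layers
    return num_layers // transformer_pipeline_size

print(layers_on_this_rank(24, 5, 0, True))  # 0
print(layers_on_this_rank(24, 5, 3, True))  # 6
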
@@ -33,6 +33,13 @@ def get_forward_backward_func():
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
# >>>
# from lutil import pax
# pax({
# "num microbatches" : get_num_microbatches(),
# "pipeline size" : args.pipeline_model_parallel_size,
# })
# <<<
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
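
The interleaved schedule tiles microbatches evenly across pipeline ranks, hence the divisibility requirement; a toy check with assumed counts:

# Hypothetical microbatch counts; mirrors the assertion above.
pipeline_model_parallel_size = 4
for num_microbatches in (8, 12, 14):
    ok = num_microbatches % pipeline_model_parallel_size == 0
    print(num_microbatches,
          'ok' if ok else
          'not divisible by pipeline-parallel size (interleaved schedule)')
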
@@ -137,11 +137,11 @@ def pretrain(train_valid_test_dataset_provider,
print_datetime('after dataloaders are built')
# >>>
from lutil import pax
pax({
"model / len" : len(model),
# "do_train": args.do_train,
})
# from lutil import pax
# pax({
# "model / len" : len(model),
# # "do_train": args.do_train,
# })
# <<<
# Print setup timing.
@@ -233,11 +233,11 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
this_model.model_type = model_type
model.append(this_model)
# >>>
from lutil import pax
pax({
"virtual size" : args.virtual_pipeline_model_parallel_size,
"model" : model,
})
# from lutil import pax
# pax({
# "virtual size" : args.virtual_pipeline_model_parallel_size,
# "model" : model,
# })
# <<<
else:
pre_process = mpu.is_pipeline_first_stage()
@@ -366,8 +366,8 @@ def setup_model_and_optimizer(model_provider_func, model_type):
model = get_model(model_provider_func, model_type)
# >>>
from lutil import pax
pax({"model": model})
# from lutil import pax
# pax({"model": model})
# <<<
unwrapped_model = unwrap_model(model,
@@ -938,8 +938,8 @@ def build_train_valid_test_data_iterators(
args.do_test = flags[2].item()
# >>>
from lutil import pax
pax({"hi": "there"})
# from lutil import pax
# pax({"hi": "there"})
# <<<
# Build iterators.