Commit a06af061 authored by Lawrence McAfee

added args.transformer_pipeline_model_parallel_size

parent c2b7d0b3
@@ -66,6 +66,11 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
+    args.transformer_pipeline_model_parallel_size = (
+        args.pipeline_model_parallel_size - 1
+        if args.standalone_embed_stage else
+        args.pipeline_model_parallel_size
+    )
     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
@@ -141,11 +146,6 @@ def parse_args(extra_args_provider=None, defaults={},
         # (args.num_layers // args.pipeline_model_parallel_size) // \
         #     args.num_layers_per_virtual_pipeline_stage
         # <<<
-        transformer_pipeline_size = (
-            args.pipeline_model_parallel_size - 1
-            if args.standalone_embed_stage else
-            args.pipeline_model_parallel_size
-        )
         args.virtual_pipeline_model_parallel_size = \
             (args.num_layers // transformer_pipeline_size) // \
             args.num_layers_per_virtual_pipeline_stage
...
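For context while reading the two hunks above: when a standalone embedding stage is enabled, one pipeline stage carries only the embedding, so only pipeline_model_parallel_size - 1 stages split the transformer layers, and the virtual pipeline size is derived from that smaller count. Below is a minimal, self-contained sketch of that arithmetic; the helper name and plain-argument signature are illustrative, not part of this commit.

def derive_parallel_sizes(world_size, tensor_mp_size, pipeline_mp_size,
                          standalone_embed_stage, num_layers,
                          num_layers_per_virtual_pipeline_stage=None):
    # Hypothetical helper mirroring the parse_args() logic in the diff above.
    # Pipeline size cannot exceed the ranks left over after tensor parallelism.
    pipeline_mp_size = min(pipeline_mp_size, world_size // tensor_mp_size)

    # With a standalone embedding stage, one stage holds no transformer
    # layers, so only (pipeline_mp_size - 1) stages split them.
    transformer_pipeline_mp_size = (
        pipeline_mp_size - 1 if standalone_embed_stage else pipeline_mp_size)

    # Virtual (interleaved) pipeline size, when requested.
    virtual_pipeline_mp_size = None
    if num_layers_per_virtual_pipeline_stage is not None:
        virtual_pipeline_mp_size = (
            (num_layers // transformer_pipeline_mp_size)
            // num_layers_per_virtual_pipeline_stage)

    return (pipeline_mp_size, transformer_pipeline_mp_size,
            virtual_pipeline_mp_size)


# Example: 16 GPUs, TP=2, PP=8, standalone embedding stage, 28 layers,
# 2 layers per virtual stage -> (8, 7, 2).
print(derive_parallel_sizes(16, 2, 8, True, 28, 2))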
@@ -343,11 +343,13 @@ def get_num_layers(args, is_encoder_and_decoder_model):
             else:
                 num_layers = args.num_layers // num_ranks_in_decoder
         else:
-            transformer_pipeline_size = (
-                get_pipeline_model_parallel_world_size() - 1
-                if args.standalone_embed_stage else
-                get_pipeline_model_parallel_world_size()
-            )
+            # >>>
+            # transformer_pipeline_size = (
+            #     get_pipeline_model_parallel_world_size() - 1
+            #     if args.standalone_embed_stage else
+            #     get_pipeline_model_parallel_world_size()
+            # )
+            # <<<
             assert args.num_layers % transformer_pipeline_size == 0, \
                 'num_layers must be divisible by transformer_pipeline_size'
             num_layers = (
@@ -359,15 +361,15 @@ def get_num_layers(args, is_encoder_and_decoder_model):
     else:
         num_layers = args.num_layers
     # >>>
-    from lutil import pax
-    pax(0, {
-        "rank" : torch.distributed.get_rank(),
-        "pipeline rank" : "%d / %d" % (
-            get_pipeline_model_parallel_rank(),
-            get_pipeline_model_parallel_world_size(),
-        ),
-        "num_layers" : num_layers,
-    })
+    # from lutil import pax
+    # pax(7, {
+    #     "rank" : torch.distributed.get_rank(),
+    #     "pipeline rank" : "%d / %d" % (
+    #         get_pipeline_model_parallel_rank(),
+    #         get_pipeline_model_parallel_world_size(),
+    #     ),
+    #     "num_layers" : num_layers,
+    # })
     # <<<
     return num_layers

...
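The get_num_layers() hunks above only comment out the per-call recomputation and the pax debug dump; the per-stage layer count still follows the same rule. A simplified sketch of that rule, restricted to the non-encoder-decoder branch and ignoring the embedding-only stage's own (empty) layer count; the function name and arguments are illustrative, not part of the diff:

def layers_per_stage(num_layers, pipeline_world_size, standalone_embed_stage):
    # Simplified sketch of the non-encoder-decoder branch of get_num_layers().
    if pipeline_world_size > 1:
        transformer_pipeline_size = (
            pipeline_world_size - 1 if standalone_embed_stage
            else pipeline_world_size)
        assert num_layers % transformer_pipeline_size == 0, \
            'num_layers must be divisible by transformer_pipeline_size'
        return num_layers // transformer_pipeline_size
    # No pipeline parallelism: the single stage owns every layer.
    return num_layers


assert layers_per_stage(24, 4, standalone_embed_stage=True) == 8   # 24 / (4 - 1)
assert layers_per_stage(24, 4, standalone_embed_stage=False) == 6  # 24 / 4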
@@ -33,6 +33,13 @@ def get_forward_backward_func():
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
+            # >>>
+            # from lutil import pax
+            # pax({
+            #     "num microbatches" : get_num_microbatches(),
+            #     "pipeline size" : args.pipeline_model_parallel_size,
+            # })
+            # <<<
             assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                 'number of microbatches is not divisible by pipeline-parallel ' \
                 'size when using interleaved schedule'
...
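The get_forward_backward_func() hunk above only adds a commented-out pax probe, but the assertion beneath it is the operative constraint: the interleaved (virtual pipeline) schedule requires the number of microbatches to be divisible by the pipeline-parallel size. A small sketch of that check in isolation, with an illustrative function name that is not part of the codebase:

def check_interleaved_schedule(num_microbatches, pipeline_mp_size,
                               virtual_pipeline_mp_size):
    # Mirrors the assertion in the hunk above; names are illustrative.
    if pipeline_mp_size > 1 and virtual_pipeline_mp_size is not None:
        assert num_microbatches % pipeline_mp_size == 0, \
            'number of microbatches is not divisible by pipeline-parallel ' \
            'size when using interleaved schedule'


check_interleaved_schedule(num_microbatches=8, pipeline_mp_size=4,
                           virtual_pipeline_mp_size=2)     # passes
check_interleaved_schedule(num_microbatches=8, pipeline_mp_size=4,
                           virtual_pipeline_mp_size=None)  # non-interleaved: no check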
@@ -137,11 +137,11 @@ def pretrain(train_valid_test_dataset_provider,
     print_datetime('after dataloaders are built')

     # >>>
-    from lutil import pax
-    pax({
-        "model / len" : len(model),
-        # "do_train": args.do_train,
-    })
+    # from lutil import pax
+    # pax({
+    #     "model / len" : len(model),
+    #     # "do_train": args.do_train,
+    # })
     # <<<

     # Print setup timing.
@@ -233,11 +233,11 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
             this_model.model_type = model_type
             model.append(this_model)
         # >>>
-        from lutil import pax
-        pax({
-            "virtual size" : args.virtual_pipeline_model_parallel_size,
-            "model" : model,
-        })
+        # from lutil import pax
+        # pax({
+        #     "virtual size" : args.virtual_pipeline_model_parallel_size,
+        #     "model" : model,
+        # })
         # <<<
     else:
         pre_process = mpu.is_pipeline_first_stage()
@@ -366,8 +366,8 @@ def setup_model_and_optimizer(model_provider_func, model_type):
     model = get_model(model_provider_func, model_type)

     # >>>
-    from lutil import pax
-    pax({"model": model})
+    # from lutil import pax
+    # pax({"model": model})
     # <<<

     unwrapped_model = unwrap_model(model,
@@ -938,8 +938,8 @@ def build_train_valid_test_data_iterators(
         args.do_test = flags[2].item()

     # >>>
-    from lutil import pax
-    pax({"hi": "there"})
+    # from lutil import pax
+    # pax({"hi": "there"})
     # <<<

     # Build iterators.
...