OpenDAS / Megatron-LM · Commits

Commit a06af061, authored Jan 24, 2022 by Lawrence McAfee

added args.transformer_pipeline_model_parallel_size

Parent: c2b7d0b3
Showing 4 changed files with 42 additions and 33 deletions.
megatron/arguments.py       +5   -5
megatron/mpu/initialize.py  +16  -14
megatron/schedules.py       +7   -0
megatron/training.py        +14  -14
megatron/arguments.py
```diff
@@ -66,6 +66,11 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
+    args.transformer_pipeline_model_parallel_size = (
+        args.pipeline_model_parallel_size - 1
+        if args.standalone_embed_stage else
+        args.pipeline_model_parallel_size
+    )
     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
```
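The new argument records how many pipeline stages actually hold transformer layers: when args.standalone_embed_stage is set, the first stage carries only the embedding, so it is excluded from the count. A minimal sketch of the same derivation, using a plain namespace and illustrative values that are not taken from the commit:

```python
from types import SimpleNamespace

def derive_transformer_pipeline_size(args):
    # Mirrors the added logic: drop the embedding-only stage from the count
    # of pipeline stages that run transformer layers.
    return (args.pipeline_model_parallel_size - 1
            if args.standalone_embed_stage
            else args.pipeline_model_parallel_size)

# Illustrative values, not from the commit.
args = SimpleNamespace(pipeline_model_parallel_size=8, standalone_embed_stage=True)
print(derive_transformer_pipeline_size(args))  # 7

args.standalone_embed_stage = False
print(derive_transformer_pipeline_size(args))  # 8
```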
```diff
@@ -141,11 +146,6 @@ def parse_args(extra_args_provider=None, defaults={},
         # (args.num_layers // args.pipeline_model_parallel_size) // \
         # args.num_layers_per_virtual_pipeline_stage
         # <<<
-        transformer_pipeline_size = (
-            args.pipeline_model_parallel_size - 1
-            if args.standalone_embed_stage else
-            args.pipeline_model_parallel_size
-        )
         args.virtual_pipeline_model_parallel_size = \
             (args.num_layers // transformer_pipeline_size) // \
             args.num_layers_per_virtual_pipeline_stage
```
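The interleaved-schedule setup divides args.num_layers by the transformer pipeline size and then by the per-stage chunk size to get the virtual pipeline size. A worked example of that arithmetic with illustrative numbers (not from the commit):

```python
# Illustrative numbers, not from the commit.
num_layers = 48
transformer_pipeline_size = 8             # pipeline stages that hold transformer layers
num_layers_per_virtual_pipeline_stage = 2

layers_per_stage = num_layers // transformer_pipeline_size            # 6
virtual_pipeline_model_parallel_size = (
    layers_per_stage // num_layers_per_virtual_pipeline_stage)        # 3

# Each pipeline rank would own 3 model chunks of 2 layers each.
print(virtual_pipeline_model_parallel_size)
```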
megatron/mpu/initialize.py
```diff
@@ -343,11 +343,13 @@ def get_num_layers(args, is_encoder_and_decoder_model):
             else:
                 num_layers = args.num_layers // num_ranks_in_decoder
         else:
-            transformer_pipeline_size = (
-                get_pipeline_model_parallel_world_size() - 1
-                if args.standalone_embed_stage else
-                get_pipeline_model_parallel_world_size()
-            )
+            # >>>
+            # transformer_pipeline_size = (
+            #     get_pipeline_model_parallel_world_size() - 1
+            #     if args.standalone_embed_stage else
+            #     get_pipeline_model_parallel_world_size()
+            # )
+            # <<<
             assert args.num_layers % transformer_pipeline_size == 0, \
                 'num_layers must be divisible by transformer_pipeline_size'
             num_layers = (
```
```diff
@@ -359,15 +361,15 @@ def get_num_layers(args, is_encoder_and_decoder_model):
     else:
         num_layers = args.num_layers
     # >>>
-    from lutil import pax
-    pax(0, {
-        "rank" : torch.distributed.get_rank(),
-        "pipeline rank" : "%d / %d" % (
-            get_pipeline_model_parallel_rank(),
-            get_pipeline_model_parallel_world_size(),
-        ),
-        "num_layers" : num_layers,
-    })
+    # from lutil import pax
+    # pax(7, {
+    #     "rank" : torch.distributed.get_rank(),
+    #     "pipeline rank" : "%d / %d" % (
+    #         get_pipeline_model_parallel_rank(),
+    #         get_pipeline_model_parallel_world_size(),
+    #     ),
+    #     "num_layers" : num_layers,
+    # })
     # <<<
     return num_layers
```
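Before this commit, get_num_layers derived a local transformer_pipeline_size from the pipeline world size and args.standalone_embed_stage and split args.num_layers evenly across it, asserting divisibility first; the commit comments that derivation out and disables the pax dump of the per-rank result. A self-contained sketch of the pre-existing arithmetic, with the world size passed in explicitly instead of read from mpu, purely for illustration:

```python
def layers_on_this_rank(total_layers: int,
                        pipeline_world_size: int,
                        standalone_embed_stage: bool) -> int:
    """Sketch of the per-rank layer count in get_num_layers (decoder-only case;
    encoder-decoder splitting is handled by a separate branch)."""
    transformer_pipeline_size = (pipeline_world_size - 1
                                 if standalone_embed_stage
                                 else pipeline_world_size)
    assert total_layers % transformer_pipeline_size == 0, \
        'num_layers must be divisible by transformer_pipeline_size'
    return total_layers // transformer_pipeline_size

# Illustrative values, not from the commit.
print(layers_on_this_rank(48, 8, standalone_embed_stage=False))  # 6
print(layers_on_this_rank(49, 8, standalone_embed_stage=True))   # 7
```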
megatron/schedules.py
```diff
@@ -33,6 +33,13 @@ def get_forward_backward_func():
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
+            # >>>
+            # from lutil import pax
+            # pax({
+            #     "num microbatches" : get_num_microbatches(),
+            #     "pipeline size" : args.pipeline_model_parallel_size,
+            # })
+            # <<<
             assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                 'number of microbatches is not divisible by pipeline-parallel ' \
                 'size when using interleaved schedule'
```
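The assert that follows the newly added (commented-out) debug block requires the number of microbatches to be a multiple of the pipeline-parallel size whenever the interleaved schedule is selected. A standalone restatement of that check with illustrative numbers:

```python
def check_interleaved_schedule(num_microbatches: int,
                               pipeline_model_parallel_size: int) -> None:
    # Same divisibility requirement enforced in get_forward_backward_func().
    assert num_microbatches % pipeline_model_parallel_size == 0, \
        'number of microbatches is not divisible by pipeline-parallel ' \
        'size when using interleaved schedule'

check_interleaved_schedule(num_microbatches=16, pipeline_model_parallel_size=4)   # passes
# check_interleaved_schedule(num_microbatches=10, pipeline_model_parallel_size=4) # would raise
```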
megatron/training.py
```diff
@@ -137,11 +137,11 @@ def pretrain(train_valid_test_dataset_provider,
     print_datetime('after dataloaders are built')
     # >>>
-    from lutil import pax
-    pax({
-        "model / len" : len(model),
-        # "do_train": args.do_train,
-    })
+    # from lutil import pax
+    # pax({
+    #     "model / len" : len(model),
+    #     # "do_train": args.do_train,
+    # })
     # <<<
     # Print setup timing.
```
```diff
@@ -233,11 +233,11 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
             this_model.model_type = model_type
             model.append(this_model)
         # >>>
-        from lutil import pax
-        pax({
-            "virtual size" : args.virtual_pipeline_model_parallel_size,
-            "model" : model,
-        })
+        # from lutil import pax
+        # pax({
+        #     "virtual size" : args.virtual_pipeline_model_parallel_size,
+        #     "model" : model,
+        # })
         # <<<
     else:
         pre_process = mpu.is_pipeline_first_stage()
```
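The debug block being disabled here printed len(model) next to args.virtual_pipeline_model_parallel_size; with interleaving, the surrounding loop builds one model chunk per virtual stage, so the list length should equal the virtual size. A toy sketch of that relationship (the chunk class and builder below are made up for illustration):

```python
class DummyChunk:
    """Stand-in for one model chunk returned by model_provider_func."""
    def __init__(self, idx: int):
        self.idx = idx

def build_chunks(virtual_pipeline_model_parallel_size: int):
    model = []
    for i in range(virtual_pipeline_model_parallel_size):
        this_model = DummyChunk(i)   # one chunk per virtual pipeline stage
        model.append(this_model)
    return model

model = build_chunks(4)
assert len(model) == 4               # the relationship the pax dump was inspecting
```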
```diff
@@ -366,8 +366,8 @@ def setup_model_and_optimizer(model_provider_func, model_type):
     model = get_model(model_provider_func, model_type)
     # >>>
-    from lutil import pax
-    pax({"model": model})
+    # from lutil import pax
+    # pax({"model": model})
     # <<<
     unwrapped_model = unwrap_model(model,
```
```diff
@@ -938,8 +938,8 @@ def build_train_valid_test_data_iterators(
     args.do_test = flags[2].item()
     # >>>
-    from lutil import pax
-    pax({"hi": "there"})
+    # from lutil import pax
+    # pax({"hi": "there"})
     # <<<
     # Build iterators.
```
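The `from lutil import pax` lines commented out throughout this commit come from a debugging helper that is not part of Megatron-LM. Judging only from the call sites, pax takes an optional leading rank and a dict of labeled values to dump. A hypothetical stand-in with roughly that shape, for readers without lutil:

```python
# Hypothetical stand-in for lutil.pax, inferred only from the call sites in this
# commit; lutil is the author's own tool, and the real helper may also filter by
# rank or stop execution.
def pax(rank_or_data, data=None):
    """Pretty-print a labeled dict of debug values."""
    values = rank_or_data if data is None else data
    for key, value in values.items():
        print(f"{key:>16} : {value!r}")

pax({"hi": "there"})
pax(0, {"num_layers": 12})
```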