OpenDAS / Megatron-LM · Commits

Commit 33dc8e9c
Authored Jan 24, 2022 by Lawrence McAfee

working when no interleaving

Parent: 9a8b89ac

Showing 2 changed files with 34 additions and 3 deletions:

    megatron/arguments.py        +4  -0
    megatron/mpu/initialize.py   +30 -3
megatron/arguments.py

@@ -684,6 +684,10 @@ def _add_distributed_args(parser):
     group.add_argument('--deallocate-pipeline-outputs', action='store_true',
                        default=False, help='If set, pipeline output tensors '
                        'are deallocated during the forward pass.')
+    group.add_argument('--standalone-embed-stage', action='store_true',
+                       default=False, help='If set, *input* embedding layer '
+                       'is placed on its own pipeline stage, without any '
+                       'transformer layers.')
     return parser
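The new --standalone-embed-stage flag follows the same argparse pattern as the surrounding arguments: a boolean that stays False unless the flag is passed. Below is a minimal, self-contained sketch of that behavior; the parser and group construction here are illustrative stand-ins, not the ones Megatron builds.

import argparse

# Hypothetical stand-in for Megatron's argument builder; only the new flag is shown.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='distributed')
group.add_argument('--standalone-embed-stage', action='store_true',
                   default=False, help='If set, *input* embedding layer '
                   'is placed on its own pipeline stage, without any '
                   'transformer layers.')

args = parser.parse_args([])                            # flag omitted
print(args.standalone_embed_stage)                      # False
args = parser.parse_args(['--standalone-embed-stage'])  # flag passed
print(args.standalone_embed_stage)                      # True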
megatron/mpu/initialize.py

@@ -269,6 +269,9 @@ def set_tensor_model_parallel_world_size(world_size):
 def set_pipeline_model_parallel_world_size(world_size):
+    # >>>
+    raise Exception("hi.")
+    # <<<
     """Set the pipeline model parallel size"""
     global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size

@@ -287,6 +290,9 @@ def get_pipeline_model_parallel_world_size():
     global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
         return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    # >>>
+    # raise Exception("hi.")
+    # <<<
     return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
@@ -322,6 +328,9 @@ def get_num_layers(args, is_encoder_and_decoder_model):
     """Compute the number of transformer layers resident on the current rank."""
     if get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
+            # >>>
+            raise Exception("fix for t5.")
+            # <<<
             assert args.pipeline_model_parallel_split_rank is not None
             num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
             num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
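This hunk only adds a temporary raise guarding the encoder-and-decoder (T5-style) branch; the context lines show how that branch splits the pipeline at pipeline_model_parallel_split_rank, with ranks before the split holding encoder layers and the rest holding decoder layers. A small worked sketch of that split arithmetic follows; the numbers are assumptions for illustration, and the encoder-side division mirrors the decoder-side division visible as context in the next hunk.

# Assumed example values, not taken from the commit.
num_layers = 15                          # args.num_layers
pipeline_model_parallel_world_size = 8
pipeline_model_parallel_split_rank = 3   # first decoder rank

num_ranks_in_encoder = pipeline_model_parallel_split_rank                          # 3
num_ranks_in_decoder = pipeline_model_parallel_world_size - num_ranks_in_encoder   # 5

layers_per_encoder_rank = num_layers // num_ranks_in_encoder   # 15 // 3 = 5
layers_per_decoder_rank = num_layers // num_ranks_in_decoder   # 15 // 5 = 3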
@@ -334,9 +343,27 @@ def get_num_layers(args, is_encoder_and_decoder_model):
             else:
                 num_layers = args.num_layers // num_ranks_in_decoder
         else:
-            assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
-                'num_layers must be divisible by pipeline_model_parallel_size'
-            num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
+            transformer_pipeline_size = (
+                get_pipeline_model_parallel_world_size() - 1
+                if args.standalone_embed_stage else
+                get_pipeline_model_parallel_world_size()
+            )
+            assert args.num_layers % transformer_pipeline_size == 0, \
+                'num_layers must be divisible by transformer_pipeline_size'
+            num_layers = (
+                0
+                if args.standalone_embed_stage and get_pipeline_model_parallel_rank() == 0 else
+                args.num_layers // transformer_pipeline_size
+            )
+            # >>>
+            # from lutil import pax
+            # pax({
+            #     "rank" : torch.distributed.get_rank(),
+            #     "pipeline rank" : get_pipeline_model_parallel_rank(),
+            #     "num_layers" : num_layers,
+            # })
+            # <<<
     else:
         num_layers = args.num_layers
     return num_layers
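The substantive change is in the non-T5 branch: when --standalone-embed-stage is set, the first pipeline stage holds only the input embedding, so the transformer layers are divided over one fewer stage and pipeline rank 0 gets zero layers. The sketch below is a stand-in that mirrors the arithmetic in this hunk under assumed values (24 layers, 4 pipeline stages); the real get_num_layers reads these quantities from args and torch.distributed rather than taking them as parameters.

def layers_on_rank(num_layers, pipeline_world_size, pipeline_rank, standalone_embed_stage):
    """Mirror of the new layer-partitioning arithmetic for the non-T5 branch."""
    if pipeline_world_size > 1:
        # With a standalone embedding stage, transformer layers span one fewer stage.
        transformer_pipeline_size = (
            pipeline_world_size - 1 if standalone_embed_stage else pipeline_world_size)
        assert num_layers % transformer_pipeline_size == 0, \
            'num_layers must be divisible by transformer_pipeline_size'
        # Rank 0 holds only the input embedding when the flag is set.
        return (0 if standalone_embed_stage and pipeline_rank == 0
                else num_layers // transformer_pipeline_size)
    return num_layers

# Assumed configuration: 24 transformer layers, 4 pipeline stages.
print([layers_on_rank(24, 4, r, standalone_embed_stage=False) for r in range(4)])
# [6, 6, 6, 6]  -- previous behavior: layers split evenly over all 4 stages
print([layers_on_rank(24, 4, r, standalone_embed_stage=True) for r in range(4)])
# [0, 8, 8, 8]  -- rank 0 is embedding-only; 24 layers over the remaining 3 stages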