OpenDAS / Megatron-LM · Commits

Commit c61dc22f authored Aug 21, 2021 by mshoeybi

    some cleanup

parent b8940b96
Showing 2 changed files with 23 additions and 4 deletions (+23 −4):

- megatron/arguments.py (+4 −2)
- megatron/model/transformer.py (+19 −2)
megatron/arguments.py (view file @ c61dc22f)

```diff
@@ -240,10 +240,12 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
-        assert args.tensor_model_parallel_size > 1
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
         assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you ' \
-            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
+            'need to use an activation-checkpoint method '
 
     _print_args(args)
     return args
```
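The two assertions encode how the flags interact: distributing checkpointed activations shards them across the tensor-model-parallel group, so that group must have more than one rank, and the sharding only happens inside checkpointed regions, so a checkpoint method must be selected. A minimal standalone sketch of that validation, using a plain `argparse.Namespace` in place of Megatron's parsed args; the wrapper function name is hypothetical:

```python
from argparse import Namespace


def check_activation_checkpoint_args(args):
    """Hypothetical wrapper around the two asserts added in this commit."""
    if args.distribute_checkpointed_activations:
        # Sharding is done across the tensor-model-parallel group, so the
        # group must contain more than one rank.
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'checkpointed activations only across tensor model ' \
            'parallel groups'
        # Distribution happens inside checkpointed regions, so an
        # activation-checkpoint method must be selected as well.
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you ' \
            'need to use an activation-checkpoint method'


# Passes: tensor parallelism is on and a checkpoint method is set.
check_activation_checkpoint_args(Namespace(
    distribute_checkpointed_activations=True,
    tensor_model_parallel_size=2,
    activations_checkpoint_method='uniform'))

# Raises AssertionError: no tensor-model-parallel group to shard across.
# check_activation_checkpoint_args(Namespace(
#     distribute_checkpointed_activations=True,
#     tensor_model_parallel_size=1,
#     activations_checkpoint_method='uniform'))
```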
megatron/model/transformer.py (view file @ c61dc22f)

```diff
@@ -608,6 +608,23 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward
 
+        def distribute_checkpointed_activations_helper(layer_number):
+            """Distribute checkpointed activations across the tensor model
+            parallel ranks if `distribute-checkpointed-activations` is on
+            and either of the following conditions is met:
+              - it is not the first layer in the pipeline stage. The input
+                to the first layer is used in pipeline parallelism, and
+                changing its shape throws an error in the backward pass.
+              - we are at the first pipeline stage, so the input tensor is
+                not used in pipeline parallelism. Note that no pipeline
+                parallelism is a special case of this.
+            """
+            not_first_layer_in_pipeline_stage = (layer_number > 0)
+            is_first_pipeline_stage = (
+                mpu.get_pipeline_model_parallel_rank() == 0)
+            return self.distribute_checkpointed_activations and \
+                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
+
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and
             # checkpoint the input activation of each divided chunk.
```
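The helper's predicate is easiest to see as a truth table. A pure-Python sketch, with the pipeline rank passed in explicitly instead of read from mpu; the function name here is hypothetical:

```python
def should_distribute(flag_on, layer_number, pipeline_rank):
    # Mirrors distribute_checkpointed_activations_helper, but takes the
    # pipeline rank as an argument instead of querying mpu.
    not_first_layer_in_pipeline_stage = layer_number > 0
    is_first_pipeline_stage = pipeline_rank == 0
    return flag_on and (not_first_layer_in_pipeline_stage or
                        is_first_pipeline_stage)


# With the flag on, the only False case is layer 0 of a non-first
# pipeline stage: its input is the tensor received from the previous
# stage, which pipeline parallelism still needs at its original shape.
for layer, rank in [(0, 0), (1, 0), (0, 1), (1, 1)]:
    print(f'layer={layer} pipeline_rank={rank} '
          f'-> {should_distribute(True, layer, rank)}')
# layer=0 pipeline_rank=0 -> True
# layer=1 pipeline_rank=0 -> True
# layer=0 pipeline_rank=1 -> False
# layer=1 pipeline_rank=1 -> True
```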
```diff
@@ -616,7 +633,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations and ((l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
+                    distribute_checkpointed_activations_helper(l),
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
```
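For intuition, the 'uniform' loop above walks the stage's layers in fixed-size chunks and checkpoints each chunk as a unit. A toy sketch of the partitioning, with illustrative layer counts not taken from any real configuration:

```python
# Toy sketch of the 'uniform' partitioning; counts are illustrative.
num_layers = 8
activations_checkpoint_num_layers = 2

l = 0
while l < num_layers:
    # In the real code, mpu.checkpoint(custom(l, l + chunk_size), ...)
    # runs these layers while keeping only the chunk's input; the
    # intermediate activations are recomputed during the backward pass.
    print('checkpointed chunk:',
          list(range(l, l + activations_checkpoint_num_layers)))
    l += activations_checkpoint_num_layers
# checkpointed chunk: [0, 1]
# checkpointed chunk: [2, 3]
# checkpointed chunk: [4, 5]
# checkpointed chunk: [6, 7]
```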
```diff
@@ -627,7 +644,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations and ((l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
+                        distribute_checkpointed_activations_helper(l),
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
```
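Under the 'block' method shown in this last hunk, only the first `activations_checkpoint_num_layers` layers of the stage are checkpointed, one layer at a time; the remaining layers run normally and keep their activations. A toy sketch of that split, assuming the if/else sits inside a simple loop over layers and using illustrative counts:

```python
# Toy sketch of the 'block' method; layer counts are illustrative.
num_layers = 6
activations_checkpoint_num_layers = 2

for l in range(num_layers):
    if l < activations_checkpoint_num_layers:
        # Checkpointed path: corresponds to
        # mpu.checkpoint(custom(l, l + 1), distribute_flag, *inputs).
        print(f'layer {l}: checkpointed (activations recomputed in backward)')
    else:
        # Normal path: corresponds to custom(l, l + 1)(*inputs).
        print(f'layer {l}: normal forward (activations stored)')
```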