OpenDAS / Megatron-LM

Commit b8940b96
Authored Aug 21, 2021 by mshoeybi

    added for pp

Parent: 7f2cc3a4
Showing 2 changed files with 3 additions and 3 deletions:

  megatron/arguments.py          +1 -1
  megatron/model/transformer.py  +2 -2
megatron/arguments.py

@@ -240,10 +240,10 @@ def parse_args(extra_args_provider=None, defaults={},
             'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
+        assert args.tensor_model_parallel_size > 1
         assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
             'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
-        assert args.pipeline_model_parallel_size == 1
 
     _print_args(args)
     return args
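Read together with the commit message, this hunk appears to lift the old pipeline-parallel-size restriction and instead require a tensor-model-parallel group larger than one, since distributing checkpointed activations shards the saved chunk inputs across that group. Below is a minimal sketch of the resulting checks; the validate_checkpointing_args helper is hypothetical and stands in for the inline code in parse_args.

import argparse

def validate_checkpointing_args(args: argparse.Namespace) -> None:
    # Hypothetical helper mirroring the checks in megatron/arguments.py
    # after this commit; the real code runs inline inside parse_args().
    if args.distribute_checkpointed_activations:
        # Saved activations are sharded across the tensor-model-parallel
        # group, so that group must have more than one rank.
        assert args.tensor_model_parallel_size > 1
        # A chunking method must be chosen, otherwise there are no
        # checkpointed chunks whose inputs could be distributed.
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you ' \
            'need to use a valid checkpoint-activation method ' \
            "('uniform' or 'block')"

# Example: passes with tensor parallelism enabled; a size of 1 would raise.
validate_checkpointing_args(argparse.Namespace(
    distribute_checkpointed_activations=True,
    tensor_model_parallel_size=2,
    activations_checkpoint_method='uniform'))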
megatron/model/transformer.py

@@ -616,7 +616,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    self.distribute_checkpointed_activations and ((l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':

@@ -627,7 +627,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_checkpointed_activations and ((l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
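The second positional argument of mpu.checkpoint here toggles distribution of the checkpointed chunk's saved input. After this change it is gated on the chunk's starting layer index and the pipeline-parallel rank: the first chunk (l == 0) on every pipeline stage other than the first no longer distributes its saved input, presumably because that input is the tensor received from the previous stage. A small standalone sketch of the added condition follows; the should_distribute name and the assertions are illustrative, and only the boolean expression comes from the diff.

def should_distribute(distribute_checkpointed_activations, l, pipeline_rank):
    # Mirrors the expression added in both hunks above: distribute only
    # when the feature is on and the chunk either does not start at
    # layer 0 or runs on the first pipeline stage.
    return distribute_checkpointed_activations and (
        (l > 0) or (pipeline_rank == 0))

# The first chunk on pipeline stage 0 is still distributed ...
assert should_distribute(True, l=0, pipeline_rank=0)
# ... the first chunk on later stages is not ...
assert not should_distribute(True, l=0, pipeline_rank=1)
# ... and all later chunks are distributed on every stage.
assert should_distribute(True, l=2, pipeline_rank=3)
# With the feature off, nothing is distributed.
assert not should_distribute(False, l=2, pipeline_rank=0)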