OpenDAS / Megatron-LM
Commit d19e8ad7 ("address review comments")
Authored May 16, 2022 by Vijay Korthikanti
Parent: bbab79f8
Showing 3 changed files with 13 additions and 13 deletions:

    megatron/arguments.py          +2 -2
    megatron/model/transformer.py  +4 -4
    megatron/mpu/random.py         +7 -7
megatron/arguments.py

@@ -286,7 +286,7 @@ def parse_args(extra_args_provider=None, defaults={},
                   'Defaulting to no_persist_layer_norm=True')
 
     # Activation recomputing.
-    if args.distribute_recomputed_activations:
+    if args.distribute_saved_activations:
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
             'recomputed activations only across tensor model ' \
             'parallel groups'
@@ -502,7 +502,7 @@ def _add_training_args(parser):
                        'whole transformer layer is recomputed, '
                        '2) selective: core attention part of the transformer '
                        'layer is recomputed.')
-    group.add_argument('--distribute-recomputed-activations',
+    group.add_argument('--distribute-saved-activations',
                        action='store_true',
                        help='If set, distribute recomputed activations '
                        'across model parallel group.')
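For reference, a minimal standalone argparse sketch (an illustration only, not Megatron's actual parser setup) of how the renamed flag and the check above fit together: --distribute-saved-activations is a plain store_true option, and it is only accepted when the tensor model parallel group has more than one rank.

    import argparse

    # Hypothetical stand-alone parser; Megatron builds its arguments through
    # parse_args()/_add_training_args() as shown in the diff above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--tensor-model-parallel-size', type=int, default=1)
    parser.add_argument('--distribute-saved-activations', action='store_true',
                        help='If set, distribute recomputed activations '
                             'across model parallel group.')

    args = parser.parse_args(['--tensor-model-parallel-size', '2',
                              '--distribute-saved-activations'])
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'recomputed activations only across tensor model parallel groups'
    print(args.distribute_saved_activations)  # True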
megatron/model/transformer.py

@@ -750,8 +750,8 @@ class ParallelTransformer(MegatronModule):
         self.recompute_granularity = args.recompute_granularity
         self.recompute_method = args.recompute_method
         self.recompute_num_layers = args.recompute_num_layers
-        self.distribute_recomputed_activations = \
-            args.distribute_recomputed_activations and not args.sequence_parallel
+        self.distribute_saved_activations = \
+            args.distribute_saved_activations and not args.sequence_parallel
 
         self.sequence_parallel = args.sequence_parallel
@@ -851,7 +851,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.recompute_num_layers),
-                    self.distribute_recomputed_activations,
+                    self.distribute_saved_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.recompute_num_layers
@@ -863,7 +863,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.recompute_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_recomputed_activations,
+                        self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
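The uniform recompute path above walks the layers in chunks of recompute_num_layers and checkpoints each chunk, so only chunk boundaries are kept and the layers inside a chunk are recomputed during backward. Below is a minimal, self-contained sketch of that chunking pattern; it uses torch.utils.checkpoint as a stand-in for mpu.checkpoint, which additionally takes the distribute_saved_activations flag (disabled under sequence parallelism by the change above).

    import torch
    from torch.utils.checkpoint import checkpoint


    def run_with_uniform_recompute(layers, hidden_states, recompute_num_layers):
        """Apply `layers` to `hidden_states`, recomputing chunks of
        `recompute_num_layers` layers during the backward pass."""

        def custom(start, end):
            # Mirror the custom(l, l + n) helper: a callable that applies
            # layers[start:end] in order.
            def custom_forward(x):
                for layer in layers[start:end]:
                    x = layer(x)
                return x
            return custom_forward

        l = 0
        while l < len(layers):
            hidden_states = checkpoint(custom(l, l + recompute_num_layers),
                                       hidden_states)
            l += recompute_num_layers
        return hidden_states


    # Example: 8 small layers recomputed in chunks of 2.
    layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(8)])
    x = torch.randn(4, 16, requires_grad=True)
    out = run_with_uniform_recompute(layers, x, recompute_num_layers=2)
    out.sum().backward()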
megatron/mpu/random.py

@@ -307,10 +307,10 @@ class CheckpointFunction(torch.autograd.Function):
              tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
+    def forward(ctx, run_function, distribute_saved_activations, *args):
         ctx.run_function = run_function
-        ctx.distribute_checkpointed_activations \
-            = distribute_checkpointed_activations
+        ctx.distribute_saved_activations \
+            = distribute_saved_activations
 
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -322,7 +322,7 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if distribute_checkpointed_activations:
+        if distribute_saved_activations:
             ctx.input_0_shape = args[0].data.shape
             safely_set_viewless_tensor_data(
                 args[0],
@@ -339,7 +339,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if ctx.distribute_checkpointed_activations:
+        if ctx.distribute_saved_activations:
             safely_set_viewless_tensor_data(
                 inputs[0],
                 gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
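Conceptually, distribute_saved_activations means each tensor-parallel rank saves only its equal slice of the flattened first input in forward(), and the full tensor is reassembled (in Megatron, via gather_split_1d_tensor over the tensor model parallel group) before recomputation in backward(). Below is a single-process sketch of that split-and-reassemble idea, with world_size and rank standing in for the group size and rank.

    import torch


    def keep_local_1d_chunk(tensor, rank, world_size):
        """Keep only this rank's equal-size chunk of the flattened tensor."""
        flat = tensor.reshape(-1)
        chunk = flat.numel() // world_size
        return flat[rank * chunk:(rank + 1) * chunk].clone()


    def reassemble_chunks(chunks, shape):
        """Concatenate the per-rank chunks back into the original shape."""
        return torch.cat(chunks).view(shape)


    # Single-process illustration with a "group" of 4 ranks.
    x = torch.randn(2, 3, 8)
    world_size = 4
    chunks = [keep_local_1d_chunk(x, r, world_size) for r in range(world_size)]
    assert torch.equal(reassemble_chunks(chunks, x.shape), x)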
@@ -372,8 +372,8 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads
 
 
-def checkpoint(function, distribute_checkpointed_activations, *args):
+def checkpoint(function, distribute_saved_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
     return CheckpointFunction.apply(function,
-                                    distribute_checkpointed_activations, *args)
+                                    distribute_saved_activations, *args)
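About the (None, None) + grads return in this function: torch.autograd.Function.backward must return one entry per argument of forward(), so the two Nones stand for run_function and distribute_saved_activations, which need no gradient, followed by one gradient per tensor in *args. A simplified standalone sketch of the same calling convention (an illustration only; it omits the RNG bookkeeping and activation distribution that Megatron's CheckpointFunction performs):

    import torch


    class SimpleCheckpoint(torch.autograd.Function):
        """Recompute-in-backward with the forward(ctx, run_function,
        distribute_saved_activations, *args) signature used above."""

        @staticmethod
        def forward(ctx, run_function, distribute_saved_activations, *args):
            ctx.run_function = run_function
            ctx.distribute_saved_activations = distribute_saved_activations
            ctx.save_for_backward(*args)
            with torch.no_grad():
                return run_function(*args)

        @staticmethod
        def backward(ctx, *grad_outputs):
            # Re-run the forward pass with grad enabled on detached copies.
            inputs = [t.detach().requires_grad_(t.requires_grad)
                      for t in ctx.saved_tensors]
            with torch.enable_grad():
                outputs = ctx.run_function(*inputs)
            torch.autograd.backward(outputs, grad_outputs)
            # None for run_function, None for distribute_saved_activations,
            # then one gradient per tensor in *args.
            return (None, None) + tuple(inp.grad for inp in inputs)


    # Usage: the closure-captured weight also receives its gradient.
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, 8, requires_grad=True)
    out = SimpleCheckpoint.apply(lambda t: torch.relu(t @ w), False, x)
    out.sum().backward()
    print(x.grad.shape, w.grad.shape)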