OpenDAS / Megatron-LM · Commit bbab79f8
Authored May 16, 2022 by Vijay Korthikanti · parent 356eb36a

address review comments
Showing 3 changed files with 63 additions and 48 deletions:

    megatron/arguments.py          +39 / -32
    megatron/model/transformer.py  +20 / -16
    megatron/mpu/layers.py          +4 /  -0
megatron/arguments.py
@@ -105,14 +105,19 @@ def parse_args(extra_args_provider=None, defaults={},
     del args.model_parallel_size

     if args.checkpoint_activations:
-        args.checkpoint_granularity = 'full'
-        args.checkpoint_method = 'uniform'
+        args.recompute_granularity = 'full'
+        args.recompute_method = 'uniform'
         if args.rank == 0:
             print('--checkpoint-activations is no longer valid, '
-                  'use --checkpoint-granularity and --checkpoint-method instead. '
-                  'Defaulting to checkpoint-granularity=full and checkpoint-method=uniform.')
+                  'use --recompute-granularity and --recompute-method instead. '
+                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
     del args.checkpoint_activations

+    if args.recompute_activations:
+        args.recompute_granularity = 'selective'
+        args.recompute_method = 'uniform'
+    del args.recompute_activations
+
     # Set input defaults.
     for key in defaults:
         # For default to be valid, it should not be provided in the
@@ -280,26 +285,26 @@ def parse_args(extra_args_provider=None, defaults={},
                   'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                   'Defaulting to no_persist_layer_norm=True')

-    # Activation checkpointing.
-    if args.distribute_checkpointed_activations:
+    # Activation recomputing.
+    if args.distribute_recomputed_activations:
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
-            'checkpointed activations only across tensor model ' \
+            'recomputed activations only across tensor model ' \
             'parallel groups'
-        assert args.checkpoint_granularity == 'full', \
-            'distributed checkpoint activations is only ' \
-            'application to full checkpoint granularity'
-        assert args.checkpoint_method is not None, \
-            'for distributed checkpoint activations to work you ' \
-            'need to use a checkpoint method '
+        assert args.recompute_granularity == 'full', \
+            'distributed recompute activations is only ' \
+            'application to full recompute granularity'
+        assert args.recompute_method is not None, \
+            'for distributed recompute activations to work you ' \
+            'need to use a recompute method '
         assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
-            'distributed checkpoint activations are supported for pytorch ' \
+            'distributed recompute activations are supported for pytorch ' \
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

-    if args.checkpoint_granularity == 'selective':
-        assert args.checkpoint_method is None, \
-            'checkpoint method is not yet supported for ' \
-            'selective checkpointing granularity'
+    if args.recompute_granularity == 'selective':
+        assert args.recompute_method is None, \
+            'recompute method is not yet supported for ' \
+            'selective recomputing granularity'

     # disable async_tensor_model_parallel_allreduce when
     # model parallel memory optimization is enabled
@@ -486,33 +491,35 @@ def _add_training_args(parser):
                        ' (1024 - 16) / 8 = 126 intervals will increase'
                        'the batch size linearly to 1024. In each interval'
                        'we will use approximately 300000 / 126 = 2380 samples.')
-    group.add_argument('--checkpoint-granularity', type=str, default=None,
+    group.add_argument('--recompute-activations', action='store_true',
+                       help='recompute activation to allow for training '
+                       'with larger models, sequences, and batch sizes.')
+    group.add_argument('--recompute-granularity', type=str, default=None,
                        choices=['full', 'selective'],
                        help='Checkpoint activations to allow for training '
                        'with larger models, sequences, and batch sizes. '
                        'It is supported at two granularities 1) full: '
-                       'whole transformer layer is checkpointed, '
+                       'whole transformer layer is recomputed, '
                        '2) selective: core attention part of the transformer '
-                       'layer is checkpointed.')
-    group.add_argument('--distribute-checkpointed-activations',
+                       'layer is recomputed.')
+    group.add_argument('--distribute-recomputed-activations',
                        action='store_true',
-                       help='If set, distribute checkpointed activations '
+                       help='If set, distribute recomputed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-method', type=str, default=None,
+    group.add_argument('--recompute-method', type=str, default=None,
                        choices=['uniform', 'block'],
                        help='1) uniform: uniformly divide the total number of '
-                       'Transformer layers and checkpoint the input activation of '
+                       'Transformer layers and recompute the input activation of '
                        'each divided chunk at specified granularity, '
-                       '2) checkpoint the input activations of only a set number of '
+                       '2) recompute the input activations of only a set number of '
                        'individual Transformer layers per pipeline stage and do the '
-                       'rest without any checkpointing at specified granularity'
-                       'default) do not apply activations checkpoint to any layers')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
+                       'rest without any recomputing at specified granularity'
+                       'default) do not apply activations recompute to any layers')
+    group.add_argument('--recompute-num-layers', type=int, default=1,
                        help='1) uniform: the number of Transformer layers in each '
-                       'uniformly divided checkpoint unit, '
+                       'uniformly divided recompute unit, '
                        '2) block: the number of individual Transformer layers '
-                       'to checkpoint within each pipeline stage.')
+                       'to recompute within each pipeline stage.')

     # deprecated
     group.add_argument('--checkpoint-activations', action='store_true',
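Taken together, these argument changes replace the old --checkpoint-* options with --recompute-* equivalents and keep --checkpoint-activations only as a deprecated alias. The sketch below is a standalone illustration of how the flags compose: only the option names and the remapping logic come from the diff, while the bare argparse parser itself is an assumption, not Megatron's full parse_args().

import argparse

# Standalone sketch of the recompute-related options added above.
# Assumption: a plain argparse parser, not Megatron's full parse_args().
parser = argparse.ArgumentParser()
parser.add_argument('--recompute-activations', action='store_true')
parser.add_argument('--recompute-granularity', type=str, default=None,
                    choices=['full', 'selective'])
parser.add_argument('--recompute-method', type=str, default=None,
                    choices=['uniform', 'block'])
parser.add_argument('--recompute-num-layers', type=int, default=1)
parser.add_argument('--checkpoint-activations', action='store_true')  # deprecated

# Passing only the deprecated flag ...
args = parser.parse_args(['--checkpoint-activations'])

# ... triggers the remapping shown in the hunk at line 105 above.
if args.checkpoint_activations:
    args.recompute_granularity = 'full'
    args.recompute_method = 'uniform'

# The new shorthand --recompute-activations maps to selective granularity.
if args.recompute_activations:
    args.recompute_granularity = 'selective'
    args.recompute_method = 'uniform'

print(args.recompute_granularity, args.recompute_method)  # full uniform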
megatron/model/transformer.py
@@ -242,6 +242,10 @@ class CoreAttention(MegatronModule):
                 output_size[3],
                 dtype=query_layer.dtype,
                 device=torch.cuda.current_device())
+        else:
+            assert CoreAttention.matmul_input_buffer.size() == \
+                (output_size[0] * output_size[1], output_size[2], output_size[3]), \
+                "buffer dimensions should remain the same during the training run"

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
@@ -358,7 +362,7 @@ class ParallelAttention(MegatronModule):

         self.core_attention = CoreAttention(self.layer_number,
                                             self.attn_mask_type)
-        self.checkpoint_core_attention = args.checkpoint_granularity == 'selective'
+        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         # Output.
         self.dense = mpu.RowParallelLinear(
@@ -743,11 +747,11 @@ class ParallelTransformer(MegatronModule):
         self.drop_path_rate = drop_path_rate

         # Store activation checkpoiting flag.
-        self.checkpoint_granularity = args.checkpoint_granularity
-        self.checkpoint_method = args.checkpoint_method
-        self.checkpoint_num_layers = args.checkpoint_num_layers
-        self.distribute_checkpointed_activations = \
-            args.distribute_checkpointed_activations and not args.sequence_parallel
+        self.recompute_granularity = args.recompute_granularity
+        self.recompute_method = args.recompute_method
+        self.recompute_num_layers = args.recompute_num_layers
+        self.distribute_recomputed_activations = \
+            args.distribute_recomputed_activations and not args.sequence_parallel
         self.sequence_parallel = args.sequence_parallel
@@ -839,33 +843,33 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        if self.checkpoint_method == 'uniform':
+        if self.recompute_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
-                    custom(l, l + self.checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    custom(l, l + self.recompute_num_layers),
+                    self.distribute_recomputed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-                l += self.checkpoint_num_layers
-        elif self.checkpoint_method == 'block':
+                l += self.recompute_num_layers
+        elif self.recompute_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
-                if l < self.checkpoint_num_layers:
+                if l < self.recompute_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_recomputed_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
         else:
-            raise ValueError("Invalid activation checkpoint method.")
+            raise ValueError("Invalid activation recompute method.")

         return hidden_states
@@ -886,7 +890,7 @@ class ParallelTransformer(MegatronModule):
...
@@ -886,7 +890,7 @@ class ParallelTransformer(MegatronModule):
# Checks.
# Checks.
if
inference_params
:
if
inference_params
:
assert
self
.
checkpoint
_granularity
is
None
,
\
assert
self
.
recompute
_granularity
is
None
,
\
'inference does not work with activation checkpointing'
'inference does not work with activation checkpointing'
if
not
self
.
pre_process
:
if
not
self
.
pre_process
:
...
@@ -921,7 +925,7 @@ class ParallelTransformer(MegatronModule):

         with rng_context:
             # Forward pass.
-            if self.checkpoint_granularity == 'full':
+            if self.recompute_granularity == 'full':
                 hidden_states = self._checkpointed_forward(hidden_states,
                                                            attention_mask,
                                                            encoder_output,
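The 'uniform' and 'block' branches in _checkpointed_forward differ only in which layer ranges get wrapped in mpu.checkpoint(). The sketch below illustrates the two schedules; recomputed_ranges() is a hypothetical helper written for this illustration, not part of Megatron-LM.

# Hypothetical helper mirroring the loop above: it returns the (start, end)
# layer ranges that would be wrapped in mpu.checkpoint() for each method.
def recomputed_ranges(num_layers, recompute_num_layers, recompute_method):
    if recompute_method == 'uniform':
        # Chunk the stage into groups of `recompute_num_layers` layers and
        # recompute every chunk from its saved input.
        return [(l, min(l + recompute_num_layers, num_layers))
                for l in range(0, num_layers, recompute_num_layers)]
    elif recompute_method == 'block':
        # Recompute only the first `recompute_num_layers` layers, one by one;
        # the remaining layers keep their activations resident in memory.
        return [(l, l + 1) for l in range(min(recompute_num_layers, num_layers))]
    raise ValueError("Invalid activation recompute method.")

print(recomputed_ranges(8, 2, 'uniform'))  # [(0, 2), (2, 4), (4, 6), (6, 8)]
print(recomputed_ranges(8, 2, 'block'))    # [(0, 1), (1, 2)]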
megatron/mpu/layers.py
@@ -226,6 +226,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
                 torch.empty(dim_size, dtype=input.dtype,
                             device=torch.cuda.current_device(),
                             requires_grad=False)
+        else:
+            assert list(LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer.size()) == dim_size, \
+                "buffer dimensions should remain same during the training run"
+
         torch.distributed._all_gather_base(
             LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
             input,