OpenDAS / Megatron-LM / Commits / cfd2e216

Commit cfd2e216, authored Apr 28, 2022 by Vijay Korthikanti
Parent: 13b3dca6

    address review comments

Showing 14 changed files with 157 additions and 162 deletions (+157, -162)
Files changed:

  megatron/arguments.py              +39  -21
  megatron/model/bert_model.py        +4   -1
  megatron/model/biencoder_model.py   +1   -1
  megatron/model/gpt_model.py         +5   -2
  megatron/model/language_model.py   +11  -12
  megatron/model/t5_model.py          +4   -1
  megatron/model/transformer.py      +38  -61
  megatron/mpu/layers.py             +34  -41
  megatron/optimizer/__init__.py     +13   -0
  megatron/optimizer/optimizer.py     +0  -12
  megatron/p2p_communication.py       +3   -3
  megatron/schedules.py               +3   -3
  megatron/training.py                +1   -3
  pretrain_gpt.py                     +1   -1
megatron/arguments.py

@@ -103,12 +103,14 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size

     if args.checkpoint_activations:
-        args.activations_checkpoint_method = 'uniform'
+        args.checkpoint_granularity = 'full'
+        args.checkpoint_method = 'uniform'
         if args.rank == 0:
             print('--checkpoint-activations is no longer valid, '
-                  'use --activation-checkpoint-method instead. '
-                  'Defaulting to activation-checkpoint-method=uniform.')
+                  'use --checkpoint-granularity and --checkpoint-method instead. '
+                  'Defaulting to checkpoint-granularity=full and checkpoint-method=uniform.')
         del args.checkpoint_activations

     # Set input defaults.

@@ -283,18 +285,26 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
             'checkpointed activations only across tensor model ' \
             'parallel groups'
-        assert args.activations_checkpoint_method is not None, \
+        assert args.checkpoint_granularity == 'full', \
+            'distributed checkpoint activations is only '\
+            'application to full checkpoint granularity'
+        assert args.checkpoint_method is not None, \
             'for distributed checkpoint activations to work you '\
-            'need to use a activation-checkpoint method '
+            'need to use a checkpoint method '
         assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
             'distributed checkpoint activations are supported for pytorch ' \
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

-    # model parallel memory optmization
-    if args.model_parallel_memory_opt:
-        assert not args.async_tensor_model_parallel_allreduce
+    if args.checkpoint_granularity == 'selective':
+        assert args.checkpoint_method is None, \
+            'checkpoint method is not yet supported for '\
+            'selective checkpointing granularity'
+
+    # disable async_tensor_model_parallel_allreduce when
+    # model parallel memory optmization is enabled
+    if args.sequence_parallel:
+        args.async_tensor_model_parallel_allreduce = False

     _print_args(args)
     return args

@@ -476,30 +486,38 @@ def _add_training_args(parser):
                        ' (1024 - 16) / 8 = 126 intervals will increase'
                        'the batch size linearly to 1024. In each interval'
                        'we will use approximately 300000 / 126 = 2380 samples.')
-    group.add_argument('--checkpoint-activations', action='store_true',
-                       help='Checkpoint activation to allow for training '
-                       'with larger models, sequences, and batch sizes.')
-    group.add_argument('--checkpoint-attention', action='store_true',
-                       help='Checkpoint activation to allow for training '
-                       'with larger models, sequences, and batch sizes.')
+    group.add_argument('--checkpoint-granularity', type=str, default=None,
+                       choices=['full', 'selective'],
+                       help='Checkpoint activatins to allow for training '
+                       'with larger models, sequences, and batch sizes. '
+                       'It is supported at two granularities 1) full: '
+                       'whole transformer layer is reverse checkpointed, '
+                       '2) selective: core attention part of the transformer '
+                       'layer is reverse checkpointed.')
     group.add_argument('--distribute-checkpointed-activations',
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+    group.add_argument('--checkpoint-method', type=str, default=None,
                        choices=['uniform', 'block'],
                        help='1) uniform: uniformly divide the total number of '
                        'Transformer layers and checkpoint the input activation of '
-                       'each divided chunk, '
+                       'each divided chunk at specified granularity, '
                        '2) checkpoint the input activations of only a set number of '
                        'individual Transformer layers per pipeline stage and do the '
-                       'rest without any checkpointing'
+                       'rest without any checkpointing at specified granularity'
                        'default) do not apply activations checkpoint to any layers')
-    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                        help='1) uniform: the number of Transformer layers in each '
                        'uniformly divided checkpoint unit, '
                        '2) block: the number of individual Transformer layers '
                        'to checkpoint within each pipeline stage.')
+    # deprecated
+    group.add_argument('--checkpoint-activations', action='store_true',
+                       help='Checkpoint activation to allow for training '
+                       'with larger models, sequences, and batch sizes.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '

@@ -548,8 +566,8 @@ def _add_training_args(parser):
                        'This kernel supports only a set of hidden sizes. Please '
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
-    group.add_argument('--model-parallel-memory-opt', action='store_true',
-                       help='Enable model parallel memory optmization.')
+    group.add_argument('--sequence-parallel', action='store_true',
+                       help='Enable sequence parallel optmization.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
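A minimal sketch (not part of the commit) of how the reworked flags compose and how the deprecated --checkpoint-activations maps onto them; the flag names mirror the hunks above, the parsed values are hypothetical:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-granularity', type=str, default=None,
                        choices=['full', 'selective'])
    parser.add_argument('--checkpoint-method', type=str, default=None,
                        choices=['uniform', 'block'])
    parser.add_argument('--checkpoint-num-layers', type=int, default=1)
    parser.add_argument('--checkpoint-activations', action='store_true')  # deprecated alias

    args = parser.parse_args(
        ['--checkpoint-granularity', 'full', '--checkpoint-method', 'uniform'])

    # Same fallback the commit adds to parse_args(): the deprecated flag
    # expands into the new pair of options.
    if args.checkpoint_activations:
        args.checkpoint_granularity = 'full'
        args.checkpoint_method = 'uniform'

    print(args.checkpoint_granularity, args.checkpoint_method, args.checkpoint_num_layers)
    # full uniform 1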
megatron/model/bert_model.py

@@ -110,8 +110,11 @@ def post_language_model_processing(lm_output, pooled_output,
         binary_logits = binary_head(pooled_output)

     if lm_labels is None:
-        return lm_logits, binary_logits
+        # [s b h] => [b s h]
+        return lm_logits.transpose(0, 1).contiguous(), binary_logits
     else:
+        # [b s] => [s b]
+        lm_logits = lm_logits.transpose(0, 1).contiguous()
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
megatron/model/biencoder_model.py

@@ -291,7 +291,7 @@ class PretrainedBertModel(MegatronModule):
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
-        pooled_output = lm_output[:, 0, :]
+        pooled_output = lm_output[0, :, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)
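Under the new [s, b, h] layout the [CLS] vector is the first slice along the sequence dimension rather than along dim 1. A quick check with hypothetical sizes that the two indexings agree:

    import torch

    lm_output = torch.randn(5, 2, 4)                     # [s, b, h] layout, s=5, b=2, h=4
    cls_new = lm_output[0, :, :]                         # first token along the sequence dim
    cls_old_layout = lm_output.transpose(0, 1)[:, 0, :]  # what [:, 0, :] meant under [b, s, h]
    print(torch.equal(cls_new, cls_old_layout))          # True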
megatron/model/gpt_model.py

@@ -32,15 +32,18 @@ def post_language_model_processing(lm_output, labels, logit_weights,
                                    parallel_output,
                                    fp16_lm_cross_entropy):

-    # Output.
+    # Output. Format [s b h]
    output = parallel_lm_logits(
        lm_output,
        logit_weights,
        parallel_output)

    if labels is None:
-        return output
+        # [s b h] => [b s h]
+        return output.transpose(0, 1).contiguous()
    else:
+        # [b s] => [s b]
+        labels = labels.transpose(0, 1).contiguous()
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
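A short sketch (hypothetical shapes) of the layout contract this hunk and the bert/t5 hunks establish: logits flow through the model as [s, b, v], are transposed back to [b, s, v] only for inference callers, and labels are transposed to [s, b] before the parallel cross entropy:

    import torch

    s, b, v = 6, 2, 16                          # hypothetical sizes
    output = torch.randn(s, b, v)               # logits now arrive in [s, b, v]
    labels = torch.randint(v, (b, s))           # labels still come from the batch as [b, s]

    # Inference path: hand callers the familiar [b, s, v] layout.
    out_bsv = output.transpose(0, 1).contiguous()

    # Training path: move labels to [s, b] so they line up with the logits.
    labels_sb = labels.transpose(0, 1).contiguous()
    print(out_bsv.shape, labels_sb.shape)       # torch.Size([2, 6, 16]) torch.Size([6, 2])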
megatron/model/language_model.py

@@ -33,11 +33,11 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     args = get_args()
     # Parallel logits.
     if args.async_tensor_model_parallel_allreduce or\
-       args.model_parallel_memory_opt:
+       args.sequence_parallel:
         input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
-            model_parallel and not args.model_parallel_memory_opt
+            model_parallel and not args.sequence_parallel
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False

@@ -46,7 +46,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias, args.gradient_accumulation_fusion,
-        async_grad_allreduce, args.model_parallel_memory_opt)
+        async_grad_allreduce, args.sequence_parallel)

     # Gather if needed.
     if parallel_output:

@@ -107,9 +107,9 @@ class Pooler(MegatronModule):
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

     def forward(self, hidden_states, sequence_index=0):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]
         # sequence_index: index of the token to pool.
-        pooled = hidden_states[:, sequence_index, :]
+        pooled = hidden_states[sequence_index, :, :]
         pooled = self.dense(pooled)
         pooled = torch.tanh(pooled)
         return pooled

@@ -171,7 +171,7 @@ class Embedding(MegatronModule):
             self.tokentype_embeddings = None

         self.fp32_residual_connection = args.fp32_residual_connection
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel

         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

@@ -214,18 +214,17 @@ class Embedding(MegatronModule):
             assert self.tokentype_embeddings is None

         # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+        embeddings = embeddings.transpose(0, 1).contiguous()
+
         # If the input flag for fp32 residual connection is set, convert for float.
         if self.fp32_residual_connection:
-            embeddings = embeddings.transpose(0, 1).contiguous().float()
-        # Otherwise, leave it as is.
-        else:
-            embeddings = embeddings.transpose(0, 1).contiguous()
+            embeddings = embeddings.float()

-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)

         # Dropout.
-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             with mpu.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
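scatter_to_sequence_parallel_region leaves each tensor-parallel rank with its own slice of the sequence dimension of the [s, b, h] embeddings. Locally that is just a chunk along dim 0; a sketch with hypothetical sizes (the real op runs across the tensor-parallel group):

    import torch

    embeddings = torch.randn(8, 2, 4)      # [s, b, h] with s=8, b=2, h=4
    tp = 4                                 # hypothetical tensor-parallel world size

    # Each rank keeps one of these slices; the sequence dim shrinks to s // tp.
    shards = torch.chunk(embeddings, tp, dim=0)
    print(len(shards), shards[0].shape)    # 4 torch.Size([2, 2, 4])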
megatron/model/t5_model.py

@@ -157,8 +157,11 @@ class T5Model(MegatronModule):
                                      self.word_embeddings_weight())

             if lm_labels is None:
-                return lm_logits
+                # [s b h] => [b s h]
+                return lm_logits.transpose(0, 1).contiguous()
             else:
+                # [b s] => [s b]
+                lm_labels = lm_lables.transpose(0, 1).contiguous()
                 if self.fp16_lm_cross_entropy:
                     assert lm_logits.dtype == torch.half
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
megatron/model/transformer.py

@@ -15,6 +15,7 @@
 """Transformer."""
 import math
+import contextlib

 import torch
 import torch.nn.functional as F

@@ -27,7 +28,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

-_MATMUL_INPUT = None

 """ We use the following notation throughout this file:
      h: hidden size

@@ -167,6 +167,8 @@ class SwitchMLP(MegatronModule):
 class CoreAttention(MegatronModule):
+    matmul_input = None
+
     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
         super(CoreAttention, self).__init__()

@@ -180,7 +182,7 @@ class CoreAttention(MegatronModule):
         self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
         self.attn_mask_type = attn_mask_type
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel

         projection_size = args.kv_channels * args.num_attention_heads

@@ -193,15 +195,6 @@ class CoreAttention(MegatronModule):
         self.num_attention_heads_per_partition = mpu.divide(
             args.num_attention_heads, world_size)

-        global _MATMUL_INPUT
-        if _MATMUL_INPUT is None:
-            _MATMUL_INPUT = torch.empty(
-                args.micro_batch_size * self.num_attention_heads_per_partition,
-                args.seq_length,
-                args.seq_length,
-                dtype=torch.bfloat16,
-                device=torch.cuda.current_device())
-
         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         if self.apply_query_key_layer_scaling:

@@ -220,7 +213,7 @@ class CoreAttention(MegatronModule):
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
         self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

     def forward(self, query_layer, key_layer,
                 value_layer, attention_mask):

@@ -241,20 +234,18 @@ class CoreAttention(MegatronModule):
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)

-        # preallocting result tensor: [b * np, sq, sk]
-        #matmul_result = torch.empty(
-        #    output_size[0]*output_size[1],
-        #    output_size[2],
-        #    output_size[3],
-        #    dtype=query_layer.dtype,
-        #    device=torch.cuda.current_device())
-        global _MATMUL_INPUT
-        matmul_input = _MATMUL_INPUT
+        # preallocting input tensor: [b * np, sq, sk]
+        if CoreAttention.matmul_input is None:
+            CoreAttention.matmul_input = torch.empty(
+                output_size[0]*output_size[1],
+                output_size[2],
+                output_size[3],
+                dtype=query_layer.dtype,
+                device=torch.cuda.current_device())

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            matmul_input,
+            CoreAttention.matmul_input,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))

@@ -273,7 +264,7 @@ class CoreAttention(MegatronModule):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        if not self.model_parallel_memory_opt:
+        if not self.sequence_parallel:
             with mpu.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
         else:

@@ -334,8 +325,6 @@ class ParallelAttention(MegatronModule):
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
-        self.checkpoint_attention = args.checkpoint_attention
-        #assert args.activations_checkpoint_method is None

         projection_size = args.kv_channels * args.num_attention_heads

@@ -369,6 +358,7 @@ class ParallelAttention(MegatronModule):
         self.core_attention = CoreAttention(self.layer_number,
                                             self.attn_mask_type)
+        self.checkpoint_core_attention = args.checkpoint_granularity == 'selective'

         # Output.
         self.dense = mpu.RowParallelLinear(

@@ -491,7 +481,7 @@ class ParallelAttention(MegatronModule):
         # ==================================
         # core attention computation
         # ==================================

-        if self.checkpoint_attention:
+        if self.checkpoint_core_attention:
             context_layer = self._checkpointed_attention_forward(
                 query_layer, key_layer, value_layer, attention_mask)
         else:

@@ -564,7 +554,7 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

         # Self attention.
         self.self_attention = ParallelAttention(

@@ -582,7 +572,7 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(

@@ -595,7 +585,7 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

         # MLP
         if args.num_experts is not None:

@@ -747,12 +737,13 @@ class ParallelTransformer(MegatronModule):
         self.drop_path_rate = drop_path_rate

         # Store activation checkpoiting flag.
-        self.activations_checkpoint_method = args.activations_checkpoint_method
-        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
+        self.checkpoint_granularity = args.checkpoint_granularity
+        self.checkpoint_method = args.checkpoint_method
+        self.checkpoint_num_layers = args.checkpoint_num_layers
         self.distribute_checkpointed_activations = \
-            args.distribute_checkpointed_activations and not args.model_parallel_memory_opt
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+            args.distribute_checkpointed_activations and not args.sequence_parallel
+        self.sequence_parallel = args.sequence_parallel

         # Number of layers.
         self.num_layers = mpu.get_num_layers(

@@ -822,7 +813,7 @@ class ParallelTransformer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]

@@ -842,24 +833,24 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        if self.activations_checkpoint_method == 'uniform':
+        if self.checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
-                    custom(l, l + self.activations_checkpoint_num_layers),
+                    custom(l, l + self.checkpoint_num_layers),
                     self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-                l += self.activations_checkpoint_num_layers
-        elif self.activations_checkpoint_method == 'block':
+                l += self.checkpoint_num_layers
+        elif self.checkpoint_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
-                if l < self.activations_checkpoint_num_layers:
+                if l < self.checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
                         self.distribute_checkpointed_activations,

@@ -887,7 +878,7 @@ class ParallelTransformer(MegatronModule):
                 inference_params=None):
         # Checks.
         if inference_params:
-            assert self.activations_checkpoint_method is None, \
+            assert self.checkpoint_granularity is None, \
                 'inference does not work with activation checkpointing'

         if not self.pre_process:

@@ -915,28 +906,14 @@ class ParallelTransformer(MegatronModule):
                 keep_graph=True,
             )

-        if self.model_parallel_memory_opt:
-            with mpu.get_cuda_rng_tracker().fork():
-                # Forward pass.
-                if self.activations_checkpoint_method is not None:
-                    hidden_states = self._checkpointed_forward(hidden_states,
-                                                               attention_mask,
-                                                               encoder_output,
-                                                               enc_dec_attn_mask)
-                else:
-                    total = 0
-                    for index in range(self.num_layers):
-                        layer = self._get_layer(index)
-                        hidden_states = layer(
-                            hidden_states,
-                            attention_mask,
-                            encoder_output=encoder_output,
-                            enc_dec_attn_mask=enc_dec_attn_mask,
-                            inference_params=inference_params)
+        if self.sequence_parallel:
+            rng_context = mpu.get_cuda_rng_tracker().fork()
+        else:
+            rng_context = contextlib.nullcontext
+
+        with rng_context:
+            # Forward pass.
+            if self.activations_checkpoint_method is not None:
+                if self.checkpoint_granularity == 'full':
+                    hidden_states = self._checkpointed_forward(hidden_states,
+                                                               attention_mask,
+                                                               encoder_output,
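CoreAttention now caches the baddbmm addend as a lazily allocated class attribute sized and typed from the first call, instead of a module-level global sized from the config at init time. A standalone sketch of the pattern (hypothetical class; fixed shapes across calls are assumed, as in the commit):

    import torch

    class AttentionScores:
        # Class-level buffer, shared by every instance and allocated lazily on first use,
        # mirroring the CoreAttention.matmul_input pattern.
        matmul_input = None

        def scores(self, q, k, norm_factor):
            # q: [b*np, sq, hn], k: [b*np, sk, hn]
            b_np, sq, _ = q.shape
            sk = k.shape[1]
            if AttentionScores.matmul_input is None:
                AttentionScores.matmul_input = torch.empty(
                    b_np, sq, sk, dtype=q.dtype, device=q.device)
            # The cached tensor is passed as baddbmm's `input` argument; with beta=0.0
            # its contents are ignored, so one reusable buffer avoids re-allocating an
            # addend of shape [b*np, sq, sk] on every call.
            return torch.baddbmm(AttentionScores.matmul_input,
                                 q, k.transpose(1, 2),
                                 beta=0.0, alpha=1.0 / norm_factor)

    q = torch.randn(8, 16, 32)
    k = torch.randn(8, 16, 32)
    out = AttentionScores().scores(q, k, norm_factor=32 ** 0.5)
    print(out.shape)  # torch.Size([8, 16, 16])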
megatron/mpu/layers.py

@@ -45,9 +45,6 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}

-_TOTAL_INPUT = None
-_SUB_GRAD_INPUT = None

 def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, 'tensor_model_parallel') and
             param.tensor_model_parallel) or (

@@ -208,28 +205,32 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     Linear layer execution with asynchronous communication and gradient accumulation
     fusion in backprop.
     """
+    all_gather_buffer = None

     @staticmethod
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce, model_parallel_memory_opt):
+                async_grad_allreduce, sequence_parallel):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.model_parallel_memory_opt = model_parallel_memory_opt
+        ctx.sequence_parallel = sequence_parallel

-        if model_parallel_memory_opt:
+        if sequence_parallel:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size

-            #total_input = torch.empty(dim_size, dtype=input.dtype,
-            #                          device=torch.cuda.current_device(),
-            #                          requires_grad=False)
-            global _TOTAL_INPUT
-            total_input = _TOTAL_INPUT
-            torch.distributed._all_gather_base(total_input, input,
-                group=get_tensor_model_parallel_group())
+            if LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer is None:
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer = \
+                    torch.empty(dim_size, dtype=input.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False)
+            torch.distributed._all_gather_base(
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                input, group=get_tensor_model_parallel_group())
+            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
         else:
             total_input = input

@@ -244,27 +245,25 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size

-            #total_input = torch.empty(dim_size, dtype=input.dtype,
-            #                          device=torch.cuda.current_device(),
-            #                          requires_grad=False)
-            global _TOTAL_INPUT
-            total_input = _TOTAL_INPUT
-            handle = torch.distributed._all_gather_base(total_input, input,
-                group=get_tensor_model_parallel_group(), async_op=True)
+            handle = torch.distributed._all_gather_base(
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                input, group=get_tensor_model_parallel_group(), async_op=True)

             # Delay the start of intput gradient computation shortly (3us) to have
             # gather scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
+            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
         else:
             total_input = input
         grad_input = grad_output.matmul(weight)

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             handle.wait()

         # Convert the tensor shapes to 2D for execution compatibility

@@ -281,7 +280,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         # all-reduce scheduled first and have GPU resources allocated
         _ = torch.empty(1, device=grad_output.device) + 1

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             assert not ctx.async_grad_allreduce
             dim_size = list(input.size())
             sub_grad_input = torch.empty(dim_size, dtype=input.dtype,

@@ -303,7 +302,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             handle.wait()
             return sub_grad_input, grad_weight, grad_bias, None, None, None

@@ -390,34 +389,28 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
                 args.async_tensor_model_parallel_allreduce and
                 world_size > 1)
-        self.model_parallel_memory_opt = (
-                args.model_parallel_memory_opt and
+        self.sequence_parallel = (
+                args.sequence_parallel and
                 world_size > 1)
         assert not self.async_tensor_model_parallel_allreduce or \
-            not self.model_parallel_memory_opt
+            not self.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

-        global _TOTAL_INPUT
-        if _TOTAL_INPUT is None:
-            _TOTAL_INPUT = torch.empty((args.seq_length, args.micro_batch_size,
-                                        args.hidden_size),
-                                       dtype=torch.bfloat16,
-                                       device=torch.cuda.current_device(),
-                                       requires_grad=False)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

         if self.async_tensor_model_parallel_allreduce or \
-                self.model_parallel_memory_opt:
+                self.sequence_parallel:
             input_parallel = input_
         else:
             input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
-            self.async_tensor_model_parallel_allreduce, self.model_parallel_memory_opt)
+            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
         if self.gather_output:
             # All-gather across the partitions.
-            assert not self.model_parallel_memory_opt
+            assert not self.sequence_parallel
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel

@@ -498,14 +491,14 @@ class RowParallelLinear(torch.nn.Module):
             self.bias = Parameter(torch.empty(
                 self.output_size, device=torch.cuda.current_device(),
                 dtype=args.params_dtype))
-            setattr(self.bias, 'sequence_parallel', args.model_parallel_memory_opt)
+            setattr(self.bias, 'sequence_parallel', args.sequence_parallel)

             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)

-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

@@ -515,14 +508,14 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            assert not self.model_parallel_memory_opt
+            assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, None,
             self.gradient_accumulation_fusion, None, None)

         # All-reduce across all the partitions.
-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
         else:
             output_ = reduce_from_tensor_model_parallel_region(output_parallel)
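The backward pass starts the all-gather with async_op=True, computes the input gradient that does not depend on the gathered tensor, and only then waits on the handle. A minimal sketch of that overlap pattern, using a hypothetical single-process gloo group and the public list-based all_gather so it runs without GPUs (the commit itself uses the private _all_gather_base on the tensor-parallel NCCL group):

    import torch
    import torch.distributed as dist

    # Hypothetical rendezvous address, for illustration only.
    dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500',
                            rank=0, world_size=1)

    local = torch.randn(4, 2, 8)                               # this rank's [s/tp, b, h] shard
    gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]

    handle = dist.all_gather(gathered, local, async_op=True)   # start the gather
    independent = local.sum()                                  # work that does not need `gathered`
    handle.wait()                                              # block only when the result is needed
    total_input = torch.cat(gathered, dim=0)
    print(total_input.shape, independent.item())

    dist.destroy_process_group()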
megatron/optimizer/__init__.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
 from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

@@ -90,6 +91,18 @@ def get_megatron_optimizer(model,
                          weight_decay=args.weight_decay,
                          betas=(args.adam_beta1, args.adam_beta2),
                          eps=args.adam_eps)
+
+        # preallocating state tensors to avoid fragmentation
+        for param_group in optimizer.param_groups:
+            for i, param in enumerate(param_group['params']):
+                if param.requires_grad:
+                    state = optimizer.state[param]
+                    if len(state) == 0:
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(param.data, dtype=torch.float)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(param.data, dtype=torch.float)
     elif args.optimizer == 'sgd':
         optimizer = SGD(param_groups,
                         lr=args.lr,
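The new loop touches apex FusedAdam internals; a rough standalone sketch of the same allocation pattern, using a hypothetical stand-in object instead of apex so it runs anywhere:

    import torch
    from collections import defaultdict

    # Stand-in for an apex FusedAdam instance: param_groups plus a lazily filled state dict.
    class _FakeOptimizer:
        def __init__(self, params):
            self.param_groups = [{'params': params}]
            self.state = defaultdict(dict)

    params = [torch.nn.Parameter(torch.randn(8, 8, dtype=torch.half)) for _ in range(3)]
    optimizer = _FakeOptimizer(params)

    # Same idea as the new code: allocate the fp32 Adam moments up front, in one pass,
    # instead of letting them be created one step at a time during training.
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            if param.requires_grad:
                state = optimizer.state[param]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(param.data, dtype=torch.float)
                    state['exp_avg_sq'] = torch.zeros_like(param.data, dtype=torch.float)

    print(optimizer.state[params[0]]['exp_avg'].dtype)  # torch.float32, even for fp16 params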
megatron/optimizer/optimizer.py

@@ -264,14 +264,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                     if param in self.optimizer.state:
                         self.optimizer.state[main_param] \
                             = self.optimizer.state.pop(param)
-                    #state = self.optimizer.state[main_param]
-                    #if len(state) == 0:
-                    #    # Exponential moving average of gradient values
-                    #    state['exp_avg'] = torch.zeros_like(main_param.data)
-                    #    # Exponential moving average of squared gradient values
-                    #    state['exp_avg_sq'] = torch.zeros_like(main_param.data)

                 # fp32 params.
                 elif param.type() == 'torch.cuda.FloatTensor':
                     fp32_params_this_group.append(param)

@@ -289,10 +281,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

-        # Leverage state_dict() and load_state_dict() to
-        # recast preexisting per-param state tensors
-        # self.optimizer.load_state_dict(self.optimizer.state_dict())

     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
megatron/p2p_communication.py

@@ -62,7 +62,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     override_scatter_gather_tensors_in_pipeline = False
     if args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
         if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
             tensor_chunk_shape = tensor_chunk_shape // \

@@ -95,7 +95,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # Split tensor into smaller chunks if using scatter-gather optimization.
     if not override_scatter_gather_tensors_in_pipeline and \
             args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

@@ -141,7 +141,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
             args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
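A small worked example of the chunk-size arithmetic in the hunks above, with hypothetical tensor_shape and tensor-parallel world size:

    from functools import reduce
    import operator

    tensor_shape = (256, 4, 1024)        # e.g. [s/tp, b, h] when sequence_parallel is on
    tp_world_size = 8

    numel = reduce(operator.mul, tensor_shape, 1)
    if numel % tp_world_size == 0:       # only then can the chunk be scattered evenly
        tensor_chunk_shape = numel // tp_world_size
    print(numel, tensor_chunk_shape)     # 1048576 131072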
megatron/schedules.py

@@ -279,7 +279,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

     args = get_args()
-    if args.model_parallel_memory_opt:
+    if args.sequence_parallel:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
         seq_length = args.seq_length

@@ -519,13 +519,13 @@ def get_tensor_shapes(rank, model_type):
     args = get_args()
     tensor_shapes = []

-    if args.model_parallel_memory_opt:
+    if args.sequence_parallel:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
         seq_length = args.seq_length

     if model_type == ModelType.encoder_and_decoder:
-        if args.model_parallel_memory_opt:
+        if args.sequence_parallel:
             decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
         else:
             decoder_seq_length = args.decoder_seq_length
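A worked example (hypothetical sizes) of the sequence-length division that now gates these shape computations; the (seq_length, micro_batch_size, hidden_size) construction is assumed from the surrounding get_tensor_shapes() code, which is not shown in the hunk:

    seq_length = 2048
    micro_batch_size = 4
    hidden_size = 1024
    tensor_model_parallel_world_size = 8
    sequence_parallel = True

    if sequence_parallel:
        # Each tensor-parallel rank only holds its slice of the sequence.
        seq_length = seq_length // tensor_model_parallel_world_size   # 2048 // 8 = 256

    tensor_shape = (seq_length, micro_batch_size, hidden_size)
    print(tensor_shape)   # (256, 4, 1024)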
megatron/training.py

@@ -422,7 +422,7 @@ def train_step(forward_step_func, data_iterator,
     # All-reduce layernorm parameters across model parallel nodes
     # when sequence parallelism is used
     if mpu.get_tensor_model_parallel_world_size() > 1 and \
-            args.model_parallel_memory_opt:
+            args.sequence_parallel:
         grads = []
         for model_module in model:
             unwrapped_model = unwrap_model(

@@ -432,8 +432,6 @@ def train_step(forward_step_func, data_iterator,
                 grad = param.main_grad if args.DDP_impl == 'local' else param.grad
                 grads.append(grad.data)
         coalesced = _flatten_dense_tensors(grads)
-        #TODO VIJAY
-        #coalesced /= mpu.get_tensor_model_parallel_world_size()
         torch.distributed.all_reduce(
             coalesced, group=mpu.get_tensor_model_parallel_group())
         for buf, synced in zip(grads, _unflatten_dense_tensors(
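The surrounding code coalesces the layernorm grads, all-reduces them across the tensor-parallel group, and copies them back. A non-distributed sketch of the flatten/unflatten round trip, with the collective itself stubbed out and hypothetical gradient shapes:

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    grads = [torch.randn(1024), torch.randn(1024), torch.randn(4096)]

    coalesced = _flatten_dense_tensors(grads)        # one contiguous buffer
    # torch.distributed.all_reduce(coalesced, group=...) would run here in the real code.
    coalesced.mul_(1.0)                              # placeholder for the collective
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)                            # write the reduced values back in place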
pretrain_gpt.py

@@ -59,7 +59,7 @@ def get_batch(data_iterator):
     # Unpack.
     tokens_ = data_b['text'].long()
-    labels = tokens_[:, 1:].transpose(0,1).contiguous()
+    labels = tokens_[:, 1:].contiguous()
     tokens = tokens_[:, :-1].contiguous()

     # Get the masks and postition ids.
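Labels now stay in [b, s] here; the [b s] to [s b] transpose moved into post_language_model_processing (see the gpt_model.py hunk above). A tiny sketch with hypothetical token ids:

    import torch

    tokens_ = torch.arange(10).reshape(2, 5)   # [b, s+1] batch of token ids

    labels = tokens_[:, 1:].contiguous()       # stays [b, s]; transposed later in the model
    tokens = tokens_[:, :-1].contiguous()
    print(tokens.shape, labels.shape)          # torch.Size([2, 4]) torch.Size([2, 4])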