OpenDAS / Megatron-LM

Commit cfd2e216, authored Apr 28, 2022 by Vijay Korthikanti
Parent: 13b3dca6

address review comments

Showing 14 changed files with 157 additions and 162 deletions.
megatron/arguments.py              +39  -21
megatron/model/bert_model.py        +4   -1
megatron/model/biencoder_model.py   +1   -1
megatron/model/gpt_model.py         +5   -2
megatron/model/language_model.py   +11  -12
megatron/model/t5_model.py          +4   -1
megatron/model/transformer.py      +38  -61
megatron/mpu/layers.py             +34  -41
megatron/optimizer/__init__.py     +13   -0
megatron/optimizer/optimizer.py     +0  -12
megatron/p2p_communication.py       +3   -3
megatron/schedules.py               +3   -3
megatron/training.py                +1   -3
pretrain_gpt.py                     +1   -1
megatron/arguments.py

@@ -103,12 +103,14 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
     if args.checkpoint_activations:
-        args.activations_checkpoint_method = 'uniform'
+        args.checkpoint_granularity = 'full'
+        args.checkpoint_method = 'uniform'
         if args.rank == 0:
             print('--checkpoint-activations is no longer valid, '
-                  'use --activation-checkpoint-method instead. '
-                  'Defaulting to activation-checkpoint-method=uniform.')
+                  'use --checkpoint-granularity and --checkpoint-method instead. '
+                  'Defaulting to checkpoint-granularity=full and checkpoint-method=uniform.')
         del args.checkpoint_activations

     # Set input defaults.

@@ -283,18 +285,26 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
             'checkpointed activations only across tensor model ' \
             'parallel groups'
-        assert args.activations_checkpoint_method is not None, \
+        assert args.checkpoint_granularity == 'full', \
+            'distributed checkpoint activations is only ' \
+            'applicable to full checkpoint granularity'
+        assert args.checkpoint_method is not None, \
             'for distributed checkpoint activations to work you ' \
-            'need to use a activation-checkpoint method '
+            'need to use a checkpoint method '
         assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
             'distributed checkpoint activations are supported for pytorch ' \
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

+    if args.checkpoint_granularity == 'selective':
+        assert args.checkpoint_method is None, \
+            'checkpoint method is not yet supported for ' \
+            'selective checkpointing granularity'
+
-    # model parallel memory optmization
-    if args.model_parallel_memory_opt:
-        assert not args.async_tensor_model_parallel_allreduce
+    # disable async_tensor_model_parallel_allreduce when
+    # sequence parallelism is enabled
+    if args.sequence_parallel:
+        args.async_tensor_model_parallel_allreduce = False

     _print_args(args)
     return args

@@ -476,30 +486,38 @@ def _add_training_args(parser):
                        ' (1024 - 16) / 8 = 126 intervals will increase'
                        'the batch size linearly to 1024. In each interval'
                        'we will use approximately 300000 / 126 = 2380 samples.')
-    group.add_argument('--checkpoint-activations', action='store_true',
-                       help='Checkpoint activation to allow for training '
-                       'with larger models, sequences, and batch sizes.')
-    group.add_argument('--checkpoint-attention', action='store_true',
-                       help='Checkpoint activatins to allow for training '
-                       'with larger models, sequences, and batch sizes.')
+    group.add_argument('--checkpoint-granularity', type=str, default=None,
+                       choices=['full', 'selective'],
+                       help='Checkpoint activation to allow for training '
+                       'with larger models, sequences, and batch sizes. '
+                       'It is supported at two granularities 1) full: '
+                       'whole transformer layer is reverse checkpointed, '
+                       '2) selective: core attention part of the transformer '
+                       'layer is reverse checkpointed.')
     group.add_argument('--distribute-checkpointed-activations',
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+    group.add_argument('--checkpoint-method', type=str, default=None,
                        choices=['uniform', 'block'],
                        help='1) uniform: uniformly divide the total number of '
                        'Transformer layers and checkpoint the input activation of '
-                       'each divided chunk, '
+                       'each divided chunk at specified granularity, '
                        '2) checkpoint the input activations of only a set number of '
                        'individual Transformer layers per pipeline stage and do the '
-                       'rest without any checkpointing'
+                       'rest without any checkpointing at specified granularity'
                        'default) do not apply activations checkpoint to any layers')
-    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                        help='1) uniform: the number of Transformer layers in each '
                        'uniformly divided checkpoint unit, '
                        '2) block: the number of individual Transformer layers '
                        'to checkpoint within each pipeline stage.')
+    # deprecated
+    group.add_argument('--checkpoint-activations', action='store_true',
+                       help='Checkpoint activation to allow for training '
+                       'with larger models, sequences, and batch sizes.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '

@@ -548,8 +566,8 @@ def _add_training_args(parser):
                        'This kernel supports only a set of hidden sizes. Please '
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
-    group.add_argument('--model-parallel-memory-opt', action='store_true',
-                       help='Enable model parallel memory optmization.')
+    group.add_argument('--sequence-parallel', action='store_true',
+                       help='Enable sequence parallel optimization.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
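For reference, a minimal standalone sketch of the deprecation shim above, assuming only the flag names introduced in this diff; Megatron's parse_args() applies the same mapping to its real argument namespace:

# A minimal sketch: the deprecated --checkpoint-activations now maps to
# checkpoint_granularity='full' plus checkpoint_method='uniform'.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint-granularity', type=str, default=None,
                    choices=['full', 'selective'])
parser.add_argument('--checkpoint-method', type=str, default=None,
                    choices=['uniform', 'block'])
parser.add_argument('--checkpoint-num-layers', type=int, default=1)
parser.add_argument('--checkpoint-activations', action='store_true')  # deprecated

args = parser.parse_args(['--checkpoint-activations'])
if args.checkpoint_activations:
    args.checkpoint_granularity = 'full'
    args.checkpoint_method = 'uniform'
    print('--checkpoint-activations is no longer valid, defaulting to '
          'checkpoint-granularity=full and checkpoint-method=uniform.')

assert (args.checkpoint_granularity, args.checkpoint_method) == ('full', 'uniform')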
megatron/model/bert_model.py

@@ -110,8 +110,11 @@ def post_language_model_processing(lm_output, pooled_output,
         binary_logits = binary_head(pooled_output)

     if lm_labels is None:
-        return lm_logits, binary_logits
+        # [s b h] => [b s h]
+        return lm_logits.transpose(0, 1).contiguous(), binary_logits
     else:
+        # [b s] => [s b]
+        lm_labels = lm_labels.transpose(0, 1).contiguous()
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
megatron/model/biencoder_model.py

@@ -291,7 +291,7 @@ class PretrainedBertModel(MegatronModule):
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
-        pooled_output = lm_output[:, 0, :]
+        pooled_output = lm_output[0, :, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)
megatron/model/gpt_model.py

@@ -32,15 +32,18 @@ def post_language_model_processing(lm_output, labels, logit_weights,
                                    parallel_output,
                                    fp16_lm_cross_entropy):

-    # Output.
+    # Output. Format [s b h]
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)

     if labels is None:
-        return output
+        # [s b h] => [b s h]
+        return output.transpose(0, 1).contiguous()
     else:
+        # [b s] => [s b]
+        labels = labels.transpose(0, 1).contiguous()
         if fp16_lm_cross_entropy:
             assert output.dtype == torch.half
             loss = mpu.vocab_parallel_cross_entropy(output, labels)
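The transposes above keep the language-model output in [s, b, h] until the loss and move the labels to match. A standalone sketch of the shape bookkeeping, with illustrative sizes and plain cross_entropy standing in for mpu.vocab_parallel_cross_entropy:

import torch
import torch.nn.functional as F

s, b, v = 8, 2, 32                  # sequence length, micro-batch size, vocab size
logits = torch.randn(s, b, v)       # [s, b, v], as produced by parallel_lm_logits
labels = torch.randint(v, (b, s))   # [b, s], as read from the data iterator

labels_sb = labels.transpose(0, 1).contiguous()   # [b, s] => [s, b], as in the diff
# Plain cross_entropy over flattened tokens stands in for vocab_parallel_cross_entropy.
loss = F.cross_entropy(logits.reshape(s * b, v), labels_sb.reshape(s * b))
print(loss.item())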
megatron/model/language_model.py

@@ -33,11 +33,11 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     args = get_args()
     # Parallel logits.
     if args.async_tensor_model_parallel_allreduce or \
-            args.model_parallel_memory_opt:
+            args.sequence_parallel:
         input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
-            model_parallel and not args.model_parallel_memory_opt
+            model_parallel and not args.sequence_parallel
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False

@@ -46,7 +46,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce, args.model_parallel_memory_opt)
+        async_grad_allreduce, args.sequence_parallel)

     # Gather if needed.
     if parallel_output:

@@ -107,9 +107,9 @@ class Pooler(MegatronModule):
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

     def forward(self, hidden_states, sequence_index=0):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]
         # sequence_index: index of the token to pool.
-        pooled = hidden_states[:, sequence_index, :]
+        pooled = hidden_states[sequence_index, :, :]
         pooled = self.dense(pooled)
         pooled = torch.tanh(pooled)
         return pooled

@@ -171,7 +171,7 @@ class Embedding(MegatronModule):
         self.tokentype_embeddings = None

         self.fp32_residual_connection = args.fp32_residual_connection
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

@@ -214,18 +214,17 @@ class Embedding(MegatronModule):
             assert self.tokentype_embeddings is None

         # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+        embeddings = embeddings.transpose(0, 1).contiguous()
+
         # If the input flag for fp32 residual connection is set, convert for float.
         if self.fp32_residual_connection:
-            embeddings = embeddings.transpose(0, 1).contiguous().float()
-        # Otherwise, leave it as is.
-        else:
-            embeddings = embeddings.transpose(0, 1).contiguous()
+            embeddings = embeddings.float()

-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)

         # Dropout.
-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             with mpu.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
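A quick standalone check of the Pooler change: with activations kept in [s, b, h], pooling a given token indexes dim 0 instead of dim 1 (shapes here are illustrative):

import torch

s, b, h = 8, 2, 16
hidden_states = torch.randn(s, b, h)      # [s, b, h]
sequence_index = 0

# old layout ([b, s, h]) selection, reproduced by transposing first
pooled_old = hidden_states.transpose(0, 1)[:, sequence_index, :]
# new layout ([s, b, h]) selection used in the diff above
pooled_new = hidden_states[sequence_index, :, :]

assert torch.equal(pooled_old, pooled_new)  # both are [b, h]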
megatron/model/t5_model.py

@@ -157,8 +157,11 @@ class T5Model(MegatronModule):
                                      self.word_embeddings_weight())

         if lm_labels is None:
-            return lm_logits
+            # [s b h] => [b s h]
+            return lm_logits.transpose(0, 1).contiguous()
         else:
+            # [b s] => [s b]
+            lm_labels = lm_labels.transpose(0, 1).contiguous()
             if self.fp16_lm_cross_entropy:
                 assert lm_logits.dtype == torch.half
                 lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
megatron/model/transformer.py

@@ -15,6 +15,7 @@
 """Transformer."""
 import math
+import contextlib
 import torch
 import torch.nn.functional as F

@@ -27,7 +28,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

-_MATMUL_INPUT = None

 """ We use the following notation throughout this file:
     h: hidden size

@@ -167,6 +167,8 @@ class SwitchMLP(MegatronModule):

 class CoreAttention(MegatronModule):

+    matmul_input = None
+
     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
         super(CoreAttention, self).__init__()

@@ -180,7 +182,7 @@ class CoreAttention(MegatronModule):
         self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
         self.attn_mask_type = attn_mask_type
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel

         projection_size = args.kv_channels * args.num_attention_heads

@@ -193,15 +195,6 @@ class CoreAttention(MegatronModule):
         self.num_attention_heads_per_partition = mpu.divide(
             args.num_attention_heads, world_size)

-        global _MATMUL_INPUT
-        if _MATMUL_INPUT is None:
-            _MATMUL_INPUT = torch.empty(
-                args.micro_batch_size * self.num_attention_heads_per_partition,
-                args.seq_length,
-                args.seq_length,
-                dtype=torch.bfloat16,
-                device=torch.cuda.current_device())
-
         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         if self.apply_query_key_layer_scaling:

@@ -241,20 +234,18 @@ class CoreAttention(MegatronModule):
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)

-        # preallocting result tensor: [b * np, sq, sk]
-        #matmul_result = torch.empty(
-        #    output_size[0]*output_size[1],
-        #    output_size[2],
-        #    output_size[3],
-        #    dtype=query_layer.dtype,
-        #    device=torch.cuda.current_device())
-        global _MATMUL_INPUT
-        matmul_input = _MATMUL_INPUT
+        # preallocting input tensor: [b * np, sq, sk]
+        if CoreAttention.matmul_input is None:
+            CoreAttention.matmul_input = torch.empty(
+                output_size[0]*output_size[1],
+                output_size[2],
+                output_size[3],
+                dtype=query_layer.dtype,
+                device=torch.cuda.current_device())

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            matmul_input,
+            CoreAttention.matmul_input,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))

@@ -273,7 +264,7 @@ class CoreAttention(MegatronModule):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        if not self.model_parallel_memory_opt:
+        if not self.sequence_parallel:
             with mpu.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
         else:

@@ -334,8 +325,6 @@ class ParallelAttention(MegatronModule):
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
-        self.checkpoint_attention = args.checkpoint_attention
-        #assert args.activations_checkpoint_method is None

         projection_size = args.kv_channels * args.num_attention_heads

@@ -369,6 +358,7 @@ class ParallelAttention(MegatronModule):
         self.core_attention = CoreAttention(self.layer_number,
                                             self.attn_mask_type)
+        self.checkpoint_core_attention = args.checkpoint_granularity == 'selective'

         # Output.
         self.dense = mpu.RowParallelLinear(

@@ -491,7 +481,7 @@ class ParallelAttention(MegatronModule):
         # ==================================
         # core attention computation
         # ==================================

-        if self.checkpoint_attention:
+        if self.checkpoint_core_attention:
             context_layer = self._checkpointed_attention_forward(
                 query_layer, key_layer, value_layer, attention_mask)
         else:

@@ -564,7 +554,7 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

         # Self attention.
         self.self_attention = ParallelAttention(

@@ -582,7 +572,7 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(

@@ -595,7 +585,7 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

         # MLP
         if args.num_experts is not None:

@@ -747,12 +737,13 @@ class ParallelTransformer(MegatronModule):
         self.drop_path_rate = drop_path_rate

         # Store activation checkpoiting flag.
-        self.activations_checkpoint_method = args.activations_checkpoint_method
-        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
+        self.checkpoint_granularity = args.checkpoint_granularity
+        self.checkpoint_method = args.checkpoint_method
+        self.checkpoint_num_layers = args.checkpoint_num_layers
         self.distribute_checkpointed_activations = \
-            args.distribute_checkpointed_activations and not args.model_parallel_memory_opt
+            args.distribute_checkpointed_activations and not args.sequence_parallel

-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel

         # Number of layers.
         self.num_layers = mpu.get_num_layers(

@@ -822,7 +813,7 @@ class ParallelTransformer(MegatronModule):
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
-                sequence_parallel=args.model_parallel_memory_opt)
+                sequence_parallel=args.sequence_parallel)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]

@@ -842,24 +833,24 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        if self.activations_checkpoint_method == 'uniform':
+        if self.checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
-                    custom(l, l + self.activations_checkpoint_num_layers),
+                    custom(l, l + self.checkpoint_num_layers),
                     self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-                l += self.activations_checkpoint_num_layers
-        elif self.activations_checkpoint_method == 'block':
+                l += self.checkpoint_num_layers
+        elif self.checkpoint_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
-                if l < self.activations_checkpoint_num_layers:
+                if l < self.checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
                         self.distribute_checkpointed_activations,

@@ -887,7 +878,7 @@ class ParallelTransformer(MegatronModule):
                 inference_params=None):
         # Checks.
         if inference_params:
-            assert self.activations_checkpoint_method is None, \
+            assert self.checkpoint_granularity is None, \
                 'inference does not work with activation checkpointing'

         if not self.pre_process:

@@ -915,28 +906,14 @@ class ParallelTransformer(MegatronModule):
                 keep_graph=True,
             )

-        if self.model_parallel_memory_opt:
-            with mpu.get_cuda_rng_tracker().fork():
-                # Forward pass.
-                if self.activations_checkpoint_method is not None:
-                    hidden_states = self._checkpointed_forward(hidden_states,
-                                                               attention_mask,
-                                                               encoder_output,
-                                                               enc_dec_attn_mask)
-                else:
-                    total = 0
-                    for index in range(self.num_layers):
-                        layer = self._get_layer(index)
-                        hidden_states = layer(
-                            hidden_states,
-                            attention_mask,
-                            encoder_output=encoder_output,
-                            enc_dec_attn_mask=enc_dec_attn_mask,
-                            inference_params=inference_params)
-        else:
+        if self.sequence_parallel:
+            rng_context = mpu.get_cuda_rng_tracker().fork()
+        else:
+            rng_context = contextlib.nullcontext()
+
+        with rng_context:
             # Forward pass.
-            if self.activations_checkpoint_method is not None:
+            if self.checkpoint_granularity == 'full':
                 hidden_states = self._checkpointed_forward(hidden_states,
                                                            attention_mask,
                                                            encoder_output,
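The CoreAttention change above moves the preallocated baddbmm input from a module-level global to a lazily created class attribute sized on first use. A minimal sketch of that pattern, with an illustrative class name and shapes:

import torch

class CoreAttentionSketch:
    # Shared across instances; sized lazily from the first call's shapes.
    matmul_input = None

    def raw_scores(self, query_layer, key_layer, norm_factor):
        # query_layer: [b * np, sq, hn], key_layer: [b * np, sk, hn]
        bnp, sq, _ = query_layer.shape
        sk = key_layer.shape[1]
        if CoreAttentionSketch.matmul_input is None:
            CoreAttentionSketch.matmul_input = torch.empty(
                bnp, sq, sk, dtype=query_layer.dtype, device=query_layer.device)
        # beta=0.0 ignores the buffer's contents; reusing it just avoids
        # re-creating the [b * np, sq, sk] operand on every call.
        return torch.baddbmm(CoreAttentionSketch.matmul_input,
                             query_layer, key_layer.transpose(1, 2),
                             beta=0.0, alpha=1.0 / norm_factor)

scores = CoreAttentionSketch().raw_scores(
    torch.randn(4, 8, 16), torch.randn(4, 8, 16), norm_factor=4.0)
print(scores.shape)   # torch.Size([4, 8, 8])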
megatron/mpu/layers.py

@@ -45,9 +45,6 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}

-_TOTAL_INPUT = None
-_SUB_GRAD_INPUT = None
-

 def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, 'tensor_model_parallel') and
             param.tensor_model_parallel) or (

@@ -208,28 +205,32 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     Linear layer execution with asynchronous communication and gradient accumulation
     fusion in backprop.
     """

+    all_gather_buffer = None
+
     @staticmethod
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce, model_parallel_memory_opt):
+                async_grad_allreduce, sequence_parallel):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.model_parallel_memory_opt = model_parallel_memory_opt
+        ctx.sequence_parallel = sequence_parallel

-        if model_parallel_memory_opt:
+        if sequence_parallel:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size

-            #total_input = torch.empty(dim_size, dtype=input.dtype,
-            #                          device=torch.cuda.current_device(),
-            #                          requires_grad=False)
-            global _TOTAL_INPUT
-            total_input = _TOTAL_INPUT
-            torch.distributed._all_gather_base(total_input, input,
-                group=get_tensor_model_parallel_group())
+            if LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer is None:
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer = \
+                    torch.empty(dim_size, dtype=input.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False)
+            torch.distributed._all_gather_base(
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group())
+            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
         else:
             total_input = input

@@ -244,27 +245,25 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size

-            #total_input = torch.empty(dim_size, dtype=input.dtype,
-            #                          device=torch.cuda.current_device(),
-            #                          requires_grad=False)
-            global _TOTAL_INPUT
-            total_input = _TOTAL_INPUT
-            handle = torch.distributed._all_gather_base(total_input, input,
-                group=get_tensor_model_parallel_group(), async_op=True)
+            handle = torch.distributed._all_gather_base(
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group(), async_op=True)

             # Delay the start of intput gradient computation shortly (3us) to have
             # gather scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
+            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
         else:
             total_input = input
         grad_input = grad_output.matmul(weight)

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             handle.wait()

         # Convert the tensor shapes to 2D for execution compatibility

@@ -281,7 +280,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         # all-reduce scheduled first and have GPU resources allocated
         _ = torch.empty(1, device=grad_output.device) + 1

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             assert not ctx.async_grad_allreduce
             dim_size = list(input.size())
             sub_grad_input = torch.empty(dim_size, dtype=input.dtype,

@@ -303,7 +302,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             handle.wait()
             return sub_grad_input, grad_weight, grad_bias, None, None, None

@@ -390,34 +389,28 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
                 args.async_tensor_model_parallel_allreduce and
                 world_size > 1)
-        self.model_parallel_memory_opt = (
-                args.model_parallel_memory_opt and
+        self.sequence_parallel = (
+                args.sequence_parallel and
                 world_size > 1)
         assert not self.async_tensor_model_parallel_allreduce or \
-            not self.model_parallel_memory_opt
+            not self.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

-        global _TOTAL_INPUT
-        if _TOTAL_INPUT is None:
-            _TOTAL_INPUT = torch.empty((args.seq_length, args.micro_batch_size,
-                                        args.hidden_size), dtype=torch.bfloat16,
-                                       device=torch.cuda.current_device(),
-                                       requires_grad=False)
-
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

         if self.async_tensor_model_parallel_allreduce or \
-                self.model_parallel_memory_opt:
+                self.sequence_parallel:
             input_parallel = input_
         else:
             input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
-            self.async_tensor_model_parallel_allreduce, self.model_parallel_memory_opt)
+            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
         if self.gather_output:
             # All-gather across the partitions.
-            assert not self.model_parallel_memory_opt
+            assert not self.sequence_parallel
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel

@@ -498,14 +491,14 @@ class RowParallelLinear(torch.nn.Module):
             self.bias = Parameter(torch.empty(
                 self.output_size, device=torch.cuda.current_device(),
                 dtype=args.params_dtype))
-            setattr(self.bias, 'sequence_parallel', args.model_parallel_memory_opt)
+            setattr(self.bias, 'sequence_parallel', args.sequence_parallel)

             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

@@ -515,14 +508,14 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            assert not self.model_parallel_memory_opt
+            assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, None,
             self.gradient_accumulation_fusion, None, None)
         # All-reduce across all the partitions.
-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
         else:
             output_ = reduce_from_tensor_model_parallel_region(output_parallel)
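The LinearWithGradAccumulationAndAsyncCommunication change above caches the sequence-parallel all-gather buffer on the class instead of a module-level global. A single-process sketch of the buffer reuse, with torch.cat standing in for torch.distributed._all_gather_base and illustrative shapes:

import torch

class AllGatherBufferSketch:
    all_gather_buffer = None   # shared across calls, as in the diff above

    @classmethod
    def gather(cls, input_, world_size):
        dim_size = list(input_.size())
        dim_size[0] = dim_size[0] * world_size        # [s/p, b, h] -> [s, b, h]
        if cls.all_gather_buffer is None:
            cls.all_gather_buffer = torch.empty(dim_size, dtype=input_.dtype,
                                                requires_grad=False)
        # Stand-in for torch.distributed._all_gather_base(buffer, input_, group=...):
        # on one process every rank's shard is just this rank's shard.
        cls.all_gather_buffer.copy_(torch.cat([input_] * world_size, dim=0))
        return cls.all_gather_buffer

shard = torch.randn(4, 2, 8)                          # local [s/p, b, h] shard
total = AllGatherBufferSketch.gather(shard, world_size=2)
print(total.shape)                                    # torch.Size([8, 2, 8])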
megatron/optimizer/__init__.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
 from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

@@ -90,6 +91,18 @@ def get_megatron_optimizer(model,
                          weight_decay=args.weight_decay,
                          betas=(args.adam_beta1, args.adam_beta2),
                          eps=args.adam_eps)
+
+        # preallocating state tensors to avoid fragmentation
+        for param_group in optimizer.param_groups:
+            for i, param in enumerate(param_group['params']):
+                if param.requires_grad:
+                    state = optimizer.state[param]
+                    if len(state) == 0:
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(param.data, dtype=torch.float)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(param.data, dtype=torch.float)
+
     elif args.optimizer == 'sgd':
         optimizer = SGD(param_groups,
                         lr=args.lr,
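The block added above eagerly creates Adam's per-parameter state so it is not allocated lazily on the first step (the diff's comment ties this to memory fragmentation). A minimal sketch of the same pattern, with torch.optim.Adam standing in for apex's FusedAdam and an illustrative model:

import torch

model = torch.nn.Linear(16, 16)
# torch.optim.Adam stands in for apex.optimizers.FusedAdam used in the diff.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# preallocating state tensors to avoid fragmentation (same loop as the diff)
for param_group in optimizer.param_groups:
    for param in param_group['params']:
        if param.requires_grad:
            state = optimizer.state[param]
            if len(state) == 0:
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(param.data, dtype=torch.float)
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(param.data, dtype=torch.float)

print(sorted(optimizer.state[model.weight].keys()))   # ['exp_avg', 'exp_avg_sq']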
megatron/optimizer/optimizer.py

@@ -264,14 +264,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                     if param in self.optimizer.state:
                         self.optimizer.state[main_param] \
                             = self.optimizer.state.pop(param)

-                    #state = self.optimizer.state[main_param]
-                    #if len(state) == 0:
-                    #    # Exponential moving average of gradient values
-                    #    state['exp_avg'] = torch.zeros_like(main_param.data)
-                    #    # Exponential moving average of squared gradient values
-                    #    state['exp_avg_sq'] = torch.zeros_like(main_param.data)
-
                 # fp32 params.
                 elif param.type() == 'torch.cuda.FloatTensor':
                     fp32_params_this_group.append(param)

@@ -289,10 +281,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

-        # Leverage state_dict() and load_state_dict() to
-        # recast preexisting per-param state tensors
-        # self.optimizer.load_state_dict(self.optimizer.state_dict())
-
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
megatron/p2p_communication.py

@@ -62,7 +62,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     override_scatter_gather_tensors_in_pipeline = False
     if args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
         if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
             tensor_chunk_shape = tensor_chunk_shape // \

@@ -95,7 +95,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # Split tensor into smaller chunks if using scatter-gather optimization.
     if not override_scatter_gather_tensors_in_pipeline and \
             args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

@@ -141,7 +141,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
             args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
megatron/schedules.py

@@ -279,7 +279,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

     args = get_args()
-    if args.model_parallel_memory_opt:
+    if args.sequence_parallel:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
         seq_length = args.seq_length

@@ -519,13 +519,13 @@ def get_tensor_shapes(rank, model_type):
     args = get_args()
     tensor_shapes = []

-    if args.model_parallel_memory_opt:
+    if args.sequence_parallel:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
         seq_length = args.seq_length

     if model_type == ModelType.encoder_and_decoder:
-        if args.model_parallel_memory_opt:
+        if args.sequence_parallel:
             decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
         else:
             decoder_seq_length = args.decoder_seq_length
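The schedules.py changes above shrink the communicated sequence dimension when sequence parallelism is on. A tiny sketch of the shape arithmetic, with illustrative values:

seq_length = 2048
micro_batch_size = 4
hidden_size = 1024
tp_world_size = 8            # tensor-model-parallel world size
sequence_parallel = True

if sequence_parallel:
    seq_length = seq_length // tp_world_size

tensor_shape = (seq_length, micro_batch_size, hidden_size)
print(tensor_shape)          # (256, 4, 1024) instead of (2048, 4, 1024)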
megatron/training.py

@@ -422,7 +422,7 @@ def train_step(forward_step_func, data_iterator,
     # All-reduce layernorm parameters across model parallel nodes
     # when sequence parallelism is used
     if mpu.get_tensor_model_parallel_world_size() > 1 and \
-            args.model_parallel_memory_opt:
+            args.sequence_parallel:
         grads = []
         for model_module in model:
             unwrapped_model = unwrap_model(

@@ -432,8 +432,6 @@ def train_step(forward_step_func, data_iterator,
                 grad = param.main_grad if args.DDP_impl == 'local' else param.grad
                 grads.append(grad.data)
         coalesced = _flatten_dense_tensors(grads)
-        #TODO VIJAY
-        #coalesced /= mpu.get_tensor_model_parallel_world_size()
         torch.distributed.all_reduce(
             coalesced, group=mpu.get_tensor_model_parallel_group())
         for buf, synced in zip(grads, _unflatten_dense_tensors(
pretrain_gpt.py

@@ -59,7 +59,7 @@ def get_batch(data_iterator):
     # Unpack.
     tokens_ = data_b['text'].long()
-    labels = tokens_[:, 1:].transpose(0, 1).contiguous()
+    labels = tokens_[:, 1:].contiguous()
     tokens = tokens_[:, :-1].contiguous()

     # Get the masks and postition ids.