OpenDAS / Megatron-LM · Commits · 691747b1

Commit 691747b1, authored Jan 13, 2020 by Mohammad Shoeybi
Added query-key layer scaling and softmax fp32 option.
Parent: 7a3b4c15

Showing 8 changed files with 69 additions and 20 deletions:

  arguments.py                        +6   -0
  megatron/model/bert_model.py        +6   -2
  megatron/model/gpt2_model.py        +6   -2
  megatron/model/language_model.py    +6   -2
  megatron/model/transformer.py       +36  -11
  pretrain_albert.py                  +3   -1
  pretrain_bert.py                    +3   -1
  pretrain_gpt2.py                    +3   -1
arguments.py  (+6, -0)

@@ -70,6 +70,12 @@ def add_fp16_config_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode')
+    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
+                       help='Scale Q * K^T by 1 / layer-number. If this flag '
+                       'is set, then it will automatically set '
+                       'attention-softmax-in-fp32 to true')
+    group.add_argument('--attention-softmax-in-fp32', action='store_true',
+                       help='Run attention masking and softmax in fp32.')
     group.add_argument('--fp32-embedding', action='store_true',
                        help='embedding in fp32')
     group.add_argument('--fp32-layernorm', action='store_true',
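Note on how the two new flags interact: per the help text above, enabling --apply-query-key-layer-scaling also forces the attention softmax into fp32; in this commit that coupling is implemented inside ParallelSelfAttention.__init__ (see megatron/model/transformer.py below) rather than at argument-parsing time. A minimal, self-contained sketch of the same behavior, not the repository's actual add_fp16_config_args:

import argparse

parser = argparse.ArgumentParser(description='sketch of the new fp16 options')
parser.add_argument('--apply-query-key-layer-scaling', action='store_true',
                    help='Scale Q * K^T by 1 / layer-number and force the '
                         'attention softmax into fp32.')
parser.add_argument('--attention-softmax-in-fp32', action='store_true',
                    help='Run attention masking and softmax in fp32.')
args = parser.parse_args(['--apply-query-key-layer-scaling'])

# Equivalent of the coupling that the commit performs in
# ParallelSelfAttention.__init__:
if args.apply_query_key_layer_scaling:
    args.attention_softmax_in_fp32 = True

print(args.attention_softmax_in_fp32)  # True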
megatron/model/bert_model.py  (+6, -2)

@@ -119,7 +119,9 @@ class BertModel(MegatronModule):
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
-                 parallel_output=True):
+                 parallel_output=True,
+                 apply_query_key_layer_scaling=False,
+                 attention_softmax_in_fp32=False):
         super(BertModel, self).__init__()

@@ -145,7 +147,9 @@ class BertModel(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(init_method_std,
                                                          num_layers),
-            residual_connection_post_layernorm=False)
+            residual_connection_post_layernorm=False,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            attention_softmax_in_fp32=attention_softmax_in_fp32)
         self.lm_head = BertLMHead(
             self.language_model.embedding.word_embeddings.weight.size(0),
megatron/model/gpt2_model.py  (+6, -2)

@@ -48,7 +48,9 @@ class GPT2Model(MegatronModule):
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
-                 parallel_output=True):
+                 parallel_output=True,
+                 apply_query_key_layer_scaling=False,
+                 attention_softmax_in_fp32=False):
         super(GPT2Model, self).__init__()

@@ -72,7 +74,9 @@ class GPT2Model(MegatronModule):
             init_method=init_method_normal(init_method_std),
             scaled_init_method=scaled_init_method_normal(init_method_std,
                                                          num_layers),
-            residual_connection_post_layernorm=False)
+            residual_connection_post_layernorm=False,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            attention_softmax_in_fp32=attention_softmax_in_fp32)

     def forward(self, input_ids, position_ids, attention_mask,
megatron/model/language_model.py  (+6, -2)

@@ -60,7 +60,9 @@ def get_language_model(num_layers,
                        layernorm_epsilon,
                        init_method,
                        scaled_init_method,
-                       residual_connection_post_layernorm):
+                       residual_connection_post_layernorm,
+                       apply_query_key_layer_scaling,
+                       attention_softmax_in_fp32):

     # Transformer hyperparameters.
     transformer_hparams = TransformerHyperparameters(
         hidden_size=hidden_size,

@@ -74,7 +76,9 @@ def get_language_model(num_layers,
         output_layer_init_method=scaled_init_method,
         checkpoint_activations=checkpoint_activations,
         checkpoint_num_layers=checkpoint_num_layers,
-        apply_residual_connection_post_layernorm=residual_connection_post_layernorm)
+        apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
+        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+        attention_softmax_in_fp32=attention_softmax_in_fp32)

     # Language model.
     language_model = TransformerLanguageModel(
         transformer_hparams=transformer_hparams,
megatron/model/transformer.py  (+36, -11)

@@ -82,7 +82,9 @@ class TransformerHyperparameters:
                  output_layer_init_method=None,
                  checkpoint_activations=None,
                  checkpoint_num_layers=None,
-                 apply_residual_connection_post_layernorm=None):
+                 apply_residual_connection_post_layernorm=None,
+                 apply_query_key_layer_scaling=None,
+                 attention_softmax_in_fp32=None):
         self.params_dict = {}
         self.params_dict['hidden_size'] = hidden_size
         self.params_dict['num_layers'] = num_layers

@@ -97,6 +99,10 @@ class TransformerHyperparameters:
         self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
         self.params_dict['apply_residual_connection_post_layernorm'] \
             = apply_residual_connection_post_layernorm
+        self.params_dict['apply_query_key_layer_scaling'] \
+            = apply_query_key_layer_scaling
+        self.params_dict['attention_softmax_in_fp32'] \
+            = attention_softmax_in_fp32

     def __getitem__(self, key):

@@ -169,10 +175,17 @@ class ParallelSelfAttention(MegatronModule):
     and returns output of the same size.
     """

-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, hyperparameters, attention_mask_func, layer_number):
         super(ParallelSelfAttention, self).__init__()
         self.attention_mask_func = attention_mask_func
+        self.apply_query_key_layer_scaling \
+            = hyperparameters['apply_query_key_layer_scaling']
+        self.attention_softmax_in_fp32 \
+            = hyperparameters['attention_softmax_in_fp32']
+        if self.apply_query_key_layer_scaling:
+            self.attention_softmax_in_fp32 = True
+        self.layer_number = max(1, layer_number)

         # Per attention head and per partition values.
         world_size = mpu.get_model_parallel_world_size()

@@ -239,7 +252,11 @@ class ParallelSelfAttention(MegatronModule):
     def _get_unmasked_attention_scores(self, query_layer, key_layer):
         """Unmasked attention scores with size [b, np, s, s]."""
-        norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
+        coeff = 1
+        if self.apply_query_key_layer_scaling:
+            coeff = self.layer_number
+        norm_factor = math.sqrt(
+            coeff * math.sqrt(self.hidden_size_per_attention_head))
         # Raw attention scores. [b, np, s, s]
         return torch.matmul(query_layer / norm_factor,
                             key_layer.transpose(-1, -2) / norm_factor)

@@ -250,7 +267,9 @@ class ParallelSelfAttention(MegatronModule):
         the size [b, np, s, s].
         """
         # Attention probabilities. [b, np, s, s]
+        if self.apply_query_key_layer_scaling:
+            attention_scores = attention_scores * self.layer_number
         attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
         with mpu.get_cuda_rng_tracker().fork():

@@ -304,6 +323,10 @@ class ParallelSelfAttention(MegatronModule):
         attention_scores = self._get_unmasked_attention_scores(
             query_layer, key_layer)

+        # fp32 conversion.
+        if self.attention_softmax_in_fp32:
+            attention_scores = attention_scores.float()
+
         # Apply attention mask. [b, np, s, s]
         if get_key_value:
             with torch.no_grad():

@@ -323,6 +346,10 @@ class ParallelSelfAttention(MegatronModule):
         # Attention probabilities. [b, np, s, s]
         attention_probs = self._get_attention_probs(attention_scores)

+        # fp16 conversion
+        if self.attention_softmax_in_fp32:
+            attention_probs = attention_probs.half()
+
         # Context layer. [b, s, hp]
         context_layer = self._get_attended_context(attention_probs, value_layer)

@@ -342,7 +369,7 @@ class ParallelTransformerLayer(MegatronModule):
     Transformore layer takes input with size [b, s, h] and returns an
     output of the same size.
     """

-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, hyperparameters, attention_mask_func, layer_number):
         super(ParallelTransformerLayer, self).__init__()

@@ -356,8 +383,7 @@ class ParallelTransformerLayer(MegatronModule):
         # Self attention.
         self.attention = ParallelSelfAttention(
             hyperparameters,
-            attention_mask_func)
+            attention_mask_func, layer_number)

         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(

@@ -414,14 +440,13 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_activations = hyperparameters['checkpoint_activations']
         self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']

-        def get_layer():
+        def get_layer(layer_number):
             return ParallelTransformerLayer(
                 hyperparameters,
-                attention_mask_func)
+                attention_mask_func, layer_number)

         # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer() for _ in range(hyperparameters['num_layers'])])
+            [get_layer(i + 1) for i in range(hyperparameters['num_layers'])])

         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
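To see why the changes above leave the attention distribution unchanged: the scores are computed as (Q / n)(K / n)^T with n = sqrt(layer_number * sqrt(d_head)), i.e. Q K^T / (layer_number * sqrt(d_head)), and are multiplied back by layer_number (in fp32) just before the softmax, so the result still equals softmax(Q K^T / sqrt(d_head)) while the fp16 pre-softmax values stay smaller by a factor of layer_number. A self-contained numeric sketch of that arithmetic (masking omitted; the tensor shapes and the helper name are illustrative, not the repository's API):

import math
import torch

def query_key_scaled_probs(query, key, layer_number,
                           apply_query_key_layer_scaling=True,
                           attention_softmax_in_fp32=True):
    """Attention probabilities for [b, np, s, hn] query/key (no masking)."""
    input_dtype = query.dtype
    coeff = layer_number if apply_query_key_layer_scaling else 1
    norm_factor = math.sqrt(coeff * math.sqrt(query.size(-1)))
    # Q K^T / (coeff * sqrt(d_head)) -- smaller magnitudes in fp16.
    scores = torch.matmul(query / norm_factor,
                          key.transpose(-1, -2) / norm_factor)
    if attention_softmax_in_fp32:
        scores = scores.float()
    if apply_query_key_layer_scaling:
        # Undo the extra 1/coeff in fp32, restoring Q K^T / sqrt(d_head).
        scores = scores * layer_number
    probs = torch.nn.Softmax(dim=-1)(scores)
    return probs.to(input_dtype)

# Matches the plain softmax(Q K^T / sqrt(d_head)) reference.
q, k = torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)
ref = torch.nn.Softmax(dim=-1)(
    torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(8))
print(torch.allclose(query_key_scaled_probs(q, k, layer_number=24), ref,
                     atol=1e-6))  # True (up to rounding)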
pretrain_albert.py  (+3, -1)

@@ -47,7 +47,9 @@ def model_provider(args):
                       add_binary_head=True,
                       layernorm_epsilon=args.layernorm_epsilon,
                       num_tokentypes=args.tokentype_size,
-                      parallel_output=True)
+                      parallel_output=True,
+                      apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
+                      attention_softmax_in_fp32=args.attention_softmax_in_fp32)

     return model
pretrain_bert.py  (+3, -1)

@@ -46,7 +46,9 @@ def model_provider(args):
                       add_binary_head=True,
                       layernorm_epsilon=args.layernorm_epsilon,
                       num_tokentypes=args.tokentype_size,
-                      parallel_output=True)
+                      parallel_output=True,
+                      apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
+                      attention_softmax_in_fp32=args.attention_softmax_in_fp32)

     return model
pretrain_gpt2.py  (+3, -1)

@@ -43,7 +43,9 @@ def model_provider(args):
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      layernorm_epsilon=args.layernorm_epsilon,
-                     parallel_output=True)
+                     parallel_output=True,
+                     apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
+                     attention_softmax_in_fp32=args.attention_softmax_in_fp32)

     return model
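End-to-end, the three pretrain scripts simply forward the parsed flags into their model constructors. A minimal sketch of that wiring with a stand-in model class (GPT2ModelStub and the SimpleNamespace args are illustrative only; the keyword names are the ones added in this commit):

from types import SimpleNamespace

class GPT2ModelStub:
    """Stand-in for megatron.model.GPT2Model; records the two new kwargs."""
    def __init__(self, parallel_output=True,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32

def model_provider(args):
    # Mirrors the pretrain_gpt2.py change: thread both flags into the model.
    return GPT2ModelStub(
        parallel_output=True,
        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
        attention_softmax_in_fp32=args.attention_softmax_in_fp32)

args = SimpleNamespace(apply_query_key_layer_scaling=True,
                       attention_softmax_in_fp32=False)
model = model_provider(args)
# Inside the real model, ParallelSelfAttention then forces the fp32 softmax
# because apply_query_key_layer_scaling is set.
print(model.apply_query_key_layer_scaling)  # True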