OpenDAS / Megatron-LM / Commits

Commit deffcb6a, authored Mar 31, 2020 by Mohammad
Parent: 601b19b7

    arguments in the model refactored

Showing 11 changed files with 126 additions and 384 deletions (+126 -384)
megatron/arguments.py               +4    -0
megatron/model/bert_model.py        +13   -39
megatron/model/classification.py    +9    -38
megatron/model/gpt2_model.py        +7    -36
megatron/model/language_model.py    +26   -58
megatron/model/multiple_choice.py   +10   -37
megatron/model/transformer.py       +52   -125
pretrain_bert.py                    +2    -15
pretrain_gpt2.py                    +1    -14
tasks/glue/finetune.py              +1    -11
tasks/race/finetune.py              +1    -11
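The theme of the refactor is the same in every file below: model classes stop threading a long list of hyperparameters through their constructors and instead read them from the global argument namespace returned by megatron.get_args(). A minimal sketch of the call-site change, taken from the pretrain_gpt2.py hunk further down (argument values are the usual Megatron command-line settings):

# Before this commit: every hyperparameter is passed explicitly.
model = GPT2Model(num_layers=args.num_layers,
                  vocab_size=args.padded_vocab_size,
                  hidden_size=args.hidden_size,
                  num_attention_heads=args.num_attention_heads,
                  # ... roughly a dozen more arguments ...
                  attention_softmax_in_fp32=args.attention_softmax_in_fp32)

# After this commit: the model pulls the same values from get_args() itself,
# so the call site only names what is genuinely model-specific.
model = GPT2Model(num_tokentypes=0, parallel_output=True)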
megatron/arguments.py

@@ -108,6 +108,10 @@ def _add_network_size_args(parser):
                        'This is added for computational efficieny reasons.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--apply-residual-connection-post-layernorm',
+                       action='store_true',
+                       help='If set, use original BERT residula connection '
+                            'ordering.')
     return parser
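The new flag follows the same pattern as the surrounding network-size arguments: action='store_true' means it defaults to False and flips to True only when the flag appears on the command line (False is the GPT-2 pre-layernorm ordering, True the original BERT post-layernorm ordering, per the docstring removed from transformer.py below). A small self-contained sketch of that behaviour; the parser setup here is illustrative, only the flag name comes from the diff:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('network size')
group.add_argument('--apply-residual-connection-post-layernorm',
                   action='store_true',
                   help='If set, use original BERT residual connection ordering.')

# Flag absent -> False.
print(parser.parse_args([]).apply_residual_connection_post_layernorm)
# Flag present -> True.
print(parser.parse_args(['--apply-residual-connection-post-layernorm'])
      .apply_residual_connection_post_layernorm)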
megatron/model/bert_model.py

@@ -17,6 +17,7 @@

 import torch

+from megatron import get_args
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits

@@ -106,60 +107,33 @@ class BertLMHead(MegatronModule):

 class BertModel(MegatronModule):
     """Bert Language model."""

-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 add_binary_head=False,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=0,
-                 parallel_output=True,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_tokentypes=2, add_binary_head=True,
+                 parallel_output=True):
         super(BertModel, self).__init__()
+        args = get_args()

         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
-        init_method = init_method_normal(init_method_std)
+        init_method = init_method_normal(args.init_method_std)
+        scaled_init_method = scaled_init_method_normal(args.init_method_std,
+                                                       args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
-            attention_mask_func=bert_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
             init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            scaled_init_method=scaled_init_method)

         self.lm_head = BertLMHead(
             self.language_model.embedding.word_embeddings.weight.size(0),
-            hidden_size, init_method, layernorm_epsilon, parallel_output)
+            args.hidden_size, init_method, args.layernorm_epsilon,
+            parallel_output)
         self._lm_head_key = 'lm_head'

         if self.add_binary_head:
-            self.binary_head = get_linear_layer(hidden_size, 2, init_method)
+            self.binary_head = get_linear_layer(args.hidden_size, 2,
+                                                init_method)
             self._binary_head_key = 'binary_head'
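BertModel (and every class below) now calls get_args() inside __init__ rather than receiving hyperparameters through its signature. The source of those values is a module-level namespace that Megatron fills in once during initialization; the sketch below is an illustrative reduction of that pattern, not the actual megatron implementation (the names _GLOBAL_ARGS and set_args are assumptions):

# Minimal sketch of a process-wide argument registry, assuming a single
# parse step at startup; names here are illustrative, not Megatron's own.
_GLOBAL_ARGS = None

def set_args(args):
    global _GLOBAL_ARGS
    assert _GLOBAL_ARGS is None, 'arguments should only be set once'
    _GLOBAL_ARGS = args

def get_args():
    assert _GLOBAL_ARGS is not None, 'arguments have not been initialized'
    return _GLOBAL_ARGS

# Any module can then do:
#   args = get_args()
#   hidden = args.hidden_size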
megatron/model/classification.py

@@ -17,6 +17,7 @@

 import torch

+from megatron import get_args
 from megatron.model.bert_model import bert_attention_mask_func
 from megatron.model.bert_model import bert_extended_attention_mask
 from megatron.model.bert_model import bert_position_ids

@@ -30,54 +31,24 @@ from megatron import print_rank_0

 class Classification(MegatronModule):

-    def __init__(self,
-                 num_classes,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=2,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_classes, num_tokentypes=2):
         super(Classification, self).__init__()
+        args = get_args()

         self.num_classes = num_classes
-        init_method = init_method_normal(init_method_std)
+        init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
-            attention_mask_func=bert_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
             init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            scaled_init_method=scaled_init_method_normal(args.init_method_std,
+                                                         args.num_layers))

         # Multi-choice head.
-        self.classification_dropout = torch.nn.Dropout(output_dropout_prob)
-        self.classification_head = get_linear_layer(hidden_size,
+        self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
+        self.classification_head = get_linear_layer(args.hidden_size,
                                                     self.num_classes,
                                                     init_method)
         self._classification_head_key = 'classification_head'
megatron/model/gpt2_model.py

@@ -17,6 +17,7 @@

 import torch

+from megatron import get_args
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits

@@ -34,49 +35,19 @@ def gpt2_attention_mask_func(attention_scores, ltor_mask):

 class GPT2Model(MegatronModule):
     """GPT-2 Language model."""

-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=0,
-                 parallel_output=True,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_tokentypes=0, parallel_output=True):
         super(GPT2Model, self).__init__()
+        args = get_args()

         self.parallel_output = parallel_output

         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            attention_mask_func=gpt2_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
-            init_method=init_method_normal(init_method_std),
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            init_method=init_method_normal(args.init_method_std),
+            scaled_init_method=scaled_init_method_normal(args.init_method_std,
+                                                         args.num_layers))

     def forward(self, input_ids, position_ids, attention_mask,
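Both constructors now derive their weight initializers from args.init_method_std and args.num_layers. For reference, the two helpers are conventionally defined along the following lines; this is a sketch consistent with how they are used in the hunks above, not a quotation of megatron/model/utils.py. init_method_normal(sigma) draws weights from N(0, sigma), while scaled_init_method_normal(sigma, num_layers) shrinks the standard deviation by sqrt(2 * num_layers) for the output projections:

import math
import torch

def init_method_normal(sigma):
    """Initializer that draws from N(0, sigma)."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
    return init_

def scaled_init_method_normal(sigma, num_layers):
    """Initializer for output-layer weights, scaled down with depth."""
    std = sigma / math.sqrt(2.0 * num_layers)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_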
megatron/model/language_model.py

@@ -18,13 +18,13 @@

 import torch
 import torch.nn.functional as F

+from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
-from .transformer import ParallelTransformer
-from .transformer import TransformerHyperparameters
-from .utils import gelu
-from .utils import get_linear_layer
+from megatron.model.transformer import ParallelTransformer
+from megatron.model.utils import gelu
+from megatron.model.utils import get_linear_layer


 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

@@ -40,52 +40,20 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

     # Gather if needed.
     if parallel_output:
         return logits_parallel
     else:
         return mpu.gather_from_model_parallel_region(logits_parallel)


-def get_language_model(num_layers,
-                       vocab_size,
-                       hidden_size,
-                       num_attention_heads,
-                       embedding_dropout_prob,
-                       attention_dropout_prob,
-                       output_dropout_prob,
-                       max_sequence_length,
-                       num_tokentypes,
-                       attention_mask_func,
-                       add_pooler,
-                       checkpoint_activations,
-                       checkpoint_num_layers,
-                       layernorm_epsilon,
-                       init_method,
-                       scaled_init_method,
-                       residual_connection_post_layernorm,
-                       apply_query_key_layer_scaling,
-                       attention_softmax_in_fp32):
+def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
+                       init_method, scaled_init_method):
     """Build language model and return along with the key to save."""

-    # Transformer hyperparameters.
-    transformer_hparams = TransformerHyperparameters(
-        hidden_size=hidden_size,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        attention_dropout_prob=attention_dropout_prob,
-        output_dropout_prob=output_dropout_prob,
-        mlp_activation_func=gelu,
-        layernorm_epsilon=layernorm_epsilon,
-        init_method=init_method,
-        output_layer_init_method=scaled_init_method,
-        checkpoint_activations=checkpoint_activations,
-        checkpoint_num_layers=checkpoint_num_layers,
-        apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-        attention_softmax_in_fp32=attention_softmax_in_fp32)
-
     # Language model.
     language_model = TransformerLanguageModel(
-        transformer_hparams=transformer_hparams,
         attention_mask_func=attention_mask_func,
-        vocab_size=vocab_size,
-        max_sequence_length=max_sequence_length,
-        embedding_dropout_prob=embedding_dropout_prob,
+        mlp_activation_func=gelu,
+        init_method=init_method,
+        output_layer_init_method=scaled_init_method,
         num_tokentypes=num_tokentypes,
         add_pooler=add_pooler)

     # key used for checkpoints.

@@ -293,33 +261,33 @@ class TransformerLanguageModel(MegatronModule):

         will ignore this embedding
     """

     def __init__(self,
-                 transformer_hparams,
                  attention_mask_func,
-                 vocab_size,
-                 max_sequence_length,
-                 embedding_dropout_prob,
+                 mlp_activation_func,
+                 init_method,
+                 output_layer_init_method,
                  num_tokentypes=0,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__()
+        args = get_args()

-        self.hidden_size = transformer_hparams['hidden_size']
+        self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
-        self.init_method = transformer_hparams['init_method']
+        self.init_method = init_method
         self.add_pooler = add_pooler

         # Embeddings
         self.embedding = Embedding(self.hidden_size,
-                                   vocab_size,
-                                   max_sequence_length,
-                                   embedding_dropout_prob,
+                                   args.padded_vocab_size,
+                                   args.max_position_embeddings,
+                                   args.hidden_dropout,
                                    self.init_method,
                                    self.num_tokentypes)
         self._embedding_key = 'embedding'

         # Transformer
-        self.transformer = ParallelTransformer(transformer_hparams,
-                                               attention_mask_func)
+        self.transformer = ParallelTransformer(attention_mask_func,
+                                               mlp_activation_func,
+                                               self.init_method,
+                                               output_layer_init_method)
         self._transformer_key = 'transformer'

         # Pooler
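With the TransformerHyperparameters indirection gone, a model that wants a language-model backbone supplies only the pieces that genuinely vary per model; sizes, dropout rates, layernorm epsilon, and checkpointing flags are read from get_args() further down. A hedged sketch of a hypothetical caller of the new five-argument signature (MyModel is an illustrative name, and the import paths for the init helpers are assumptions based on the imports shown in this commit):

from megatron import get_args
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.language_model import get_language_model
from megatron.model.utils import init_method_normal, scaled_init_method_normal  # assumed location
from megatron.module import MegatronModule


class MyModel(MegatronModule):
    """Hypothetical model built on the refactored backbone."""

    def __init__(self, num_tokentypes=0):
        super(MyModel, self).__init__()
        args = get_args()

        init_method = init_method_normal(args.init_method_std)
        scaled_init = scaled_init_method_normal(args.init_method_std,
                                                args.num_layers)
        # Only model-specific choices are passed; everything dimensional
        # comes from get_args() inside the language model itself.
        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method,
            scaled_init_method=scaled_init)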
megatron/model/multiple_choice.py

@@ -17,6 +17,7 @@

 import torch

+from megatron import get_args
 from megatron.model.bert_model import bert_attention_mask_func
 from megatron.model.bert_model import bert_extended_attention_mask
 from megatron.model.bert_model import bert_position_ids

@@ -30,52 +31,24 @@ from megatron import print_rank_0

 class MultipleChoice(MegatronModule):

-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=2,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_tokentypes=2):
         super(MultipleChoice, self).__init__()
+        args = get_args()

-        init_method = init_method_normal(init_method_std)
+        init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
-            attention_mask_func=bert_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
             init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            scaled_init_method=scaled_init_method_normal(args.init_method_std,
+                                                         args.num_layers))

         # Multi-choice head.
-        self.multichoice_dropout = torch.nn.Dropout(output_dropout_prob)
-        self.multichoice_head = get_linear_layer(hidden_size, 1, init_method)
+        self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
+        self.multichoice_head = get_linear_layer(args.hidden_size, 1,
+                                                 init_method)
         self._multichoice_head_key = 'multichoice_head'
megatron/model/transformer.py

@@ -20,6 +20,7 @@ import math

 import torch
 from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm

+from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule

@@ -45,85 +46,6 @@ from megatron.module import MegatronModule

        unmaksed-attention-scores, attention-mask)
     """

-class TransformerHyperparameters:
-    """Hyperparameters used to build and run the transformer.
-
-    Arguments:
-        hidden_size: hidden size (h)
-        num_layers: number of layers (l)
-        num_attention_heads: number of attention heads (n)
-        attention_dropout_prob: dropout probability for the attention
-                                probabiliies
-        output_dropout_prob: dropout probability for the output
-                             layers (attention output and mlp output)
-        mlp_activation_func: activation function for the mlp layer
-        layernorm_epsilon: tolerance parameters used for layer norm
-                           dividions
-        init_method: init method used for all weights except layer
-                     norm and output weights
-        output_layer_init_method: init method for output weights (
-                                  attention output and mlp output)
-        checkpoint_activations: flag to use activation checkpointing
-        checkpoint_num_layers: number of layers use in each chunk of
-                               activation checkpointing
-        apply_residual_connection_post_layernorm: Take the post layer-norm
-            values for resudual connecton. BERT: True, GPT-2: False
-    """
-
-    def __init__(self,
-                 hidden_size=None,
-                 num_layers=None,
-                 num_attention_heads=None,
-                 attention_dropout_prob=None,
-                 output_dropout_prob=None,
-                 mlp_activation_func=None,
-                 layernorm_epsilon=None,
-                 init_method=None,
-                 output_layer_init_method=None,
-                 checkpoint_activations=None,
-                 checkpoint_num_layers=None,
-                 apply_residual_connection_post_layernorm=None,
-                 apply_query_key_layer_scaling=None,
-                 attention_softmax_in_fp32=None):
-        self.params_dict = {}
-        self.params_dict['hidden_size'] = hidden_size
-        self.params_dict['num_layers'] = num_layers
-        self.params_dict['num_attention_heads'] = num_attention_heads
-        self.params_dict['attention_dropout_prob'] = attention_dropout_prob
-        self.params_dict['output_dropout_prob'] = output_dropout_prob
-        self.params_dict['mlp_activation_func'] = mlp_activation_func
-        self.params_dict['layernorm_epsilon'] = layernorm_epsilon
-        self.params_dict['init_method'] = init_method
-        self.params_dict['output_layer_init_method'] = output_layer_init_method
-        self.params_dict['checkpoint_activations'] = checkpoint_activations
-        self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
-        self.params_dict['apply_residual_connection_post_layernorm'] \
-            = apply_residual_connection_post_layernorm
-        self.params_dict['apply_query_key_layer_scaling'] \
-            = apply_query_key_layer_scaling
-        self.params_dict['attention_softmax_in_fp32'] \
-            = attention_softmax_in_fp32
-
-    def __getitem__(self, key):
-        """Custom retrieval with error checks."""
-        try:
-            value = self.params_dict[key]
-        except KeyError:
-            raise Exception('could not find {} in transformer '
-                            'hyperparameters'.format(key))
-        except Exception as e:
-            print('unexpected error in transformer hyperparameters:', e)
-            raise Exception()
-        else:
-            assert value is not None, \
-                'parameter value for {} is not set in transformer ' \
-                'hyperparameters'.format(key)
-            return value
-        raise Exception('should not be here')
-
-
 class ParallelMLP(MegatronModule):
     """MLP.

@@ -133,26 +55,28 @@ class ParallelMLP(MegatronModule):

     applied.
     """

-    def __init__(self, hyperparameters):
+    def __init__(self, mlp_activation_func, init_method,
+                 output_layer_init_method):
         super(ParallelMLP, self).__init__()
+        args = get_args()

         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
-            hyperparameters['hidden_size'],
-            4 * hyperparameters['hidden_size'],
+            args.hidden_size,
+            4 * args.hidden_size,
             gather_output=False,
-            init_method=hyperparameters['init_method'])
+            init_method=init_method)

-        self.activation_func = hyperparameters['mlp_activation_func']
+        self.activation_func = mlp_activation_func

         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
-            4 * hyperparameters['hidden_size'],
-            hyperparameters['hidden_size'],
+            4 * args.hidden_size,
+            args.hidden_size,
             input_is_parallel=True,
-            init_method=hyperparameters['output_layer_init_method'])
+            init_method=output_layer_init_method)

-        self.dropout = torch.nn.Dropout(hyperparameters['output_dropout_prob'])
+        self.dropout = torch.nn.Dropout(args.hidden_dropout)

     def forward(self, hidden_states):

@@ -174,51 +98,47 @@ class ParallelSelfAttention(MegatronModule):

     Self-attention layer takes input with size [b, s, h]
     and returns output of the same size.
     """

-    def __init__(self, hyperparameters, attention_mask_func, layer_number):
+    def __init__(self, attention_mask_func, init_method,
+                 output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
+        args = get_args()

         self.attention_mask_func = attention_mask_func
-        self.apply_query_key_layer_scaling \
-            = hyperparameters['apply_query_key_layer_scaling']
-        self.attention_softmax_in_fp32 \
-            = hyperparameters['attention_softmax_in_fp32']
+        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)

         # Per attention head and per partition values.
         world_size = mpu.get_model_parallel_world_size()
         self.hidden_size_per_partition = mpu.divide(
-            hyperparameters['hidden_size'], world_size)
+            args.hidden_size, world_size)
         self.hidden_size_per_attention_head = mpu.divide(
-            hyperparameters['hidden_size'],
-            hyperparameters['num_attention_heads'])
+            args.hidden_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = mpu.divide(
-            hyperparameters['num_attention_heads'], world_size)
+            args.num_attention_heads, world_size)

         # Strided linear layer.
         self.query_key_value = mpu.ColumnParallelLinear(
-            hyperparameters['hidden_size'],
-            3 * hyperparameters['hidden_size'],
+            args.hidden_size,
+            3 * args.hidden_size,
             stride=3,
             gather_output=False,
-            init_method=hyperparameters['init_method'])
+            init_method=init_method)

         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
-        self.attention_dropout = torch.nn.Dropout(
-            hyperparameters['attention_dropout_prob'])
+        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

         # Output.
         self.dense = mpu.RowParallelLinear(
-            hyperparameters['hidden_size'],
-            hyperparameters['hidden_size'],
+            args.hidden_size,
+            args.hidden_size,
             input_is_parallel=True,
-            init_method=hyperparameters['output_layer_init_method'])
-        self.output_dropout = torch.nn.Dropout(
-            hyperparameters['output_dropout_prob'])
+            init_method=output_layer_init_method)
+        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)

     def _transpose_for_scores(self, tensor):

@@ -369,30 +289,34 @@ class ParallelTransformerLayer(MegatronModule):

     Transformore layer takes input with size [b, s, h] and returns an
     output of the same size.
     """

-    def __init__(self, hyperparameters, attention_mask_func, layer_number):
+    def __init__(self, attention_mask_func, mlp_activation_func,
+                 init_method, output_layer_init_method, layer_number):
+        args = get_args()
         super(ParallelTransformerLayer, self).__init__()

         self.layer_number = layer_number
         self.apply_residual_connection_post_layernorm \
-            = hyperparameters['apply_residual_connection_post_layernorm']
+            = args.apply_residual_connection_post_layernorm

         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)

         # Self attention.
-        self.attention = ParallelSelfAttention(hyperparameters,
-                                               attention_mask_func,
-                                               layer_number)
+        self.attention = ParallelSelfAttention(attention_mask_func,
+                                               init_method,
+                                               output_layer_init_method,
+                                               layer_number)

         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)

         # MLP
-        self.mlp = ParallelMLP(hyperparameters)
+        self.mlp = ParallelMLP(mlp_activation_func, init_method,
+                               output_layer_init_method)

     def forward(self, hidden_states, attention_mask, layer_past=None,

@@ -434,25 +358,28 @@ class ParallelTransformerLayer(MegatronModule):

 class ParallelTransformer(MegatronModule):
     """Transformer class."""

-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, attention_mask_func, mlp_activation_func,
+                 init_method, output_layer_init_method):
         super(ParallelTransformer, self).__init__()
+        args = get_args()

         # Store activation checkpoiting flag.
-        self.checkpoint_activations = hyperparameters['checkpoint_activations']
-        self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']
+        self.checkpoint_activations = args.checkpoint_activations
+        self.checkpoint_num_layers = args.checkpoint_num_layers

         def get_layer(layer_number):
-            return ParallelTransformerLayer(hyperparameters,
-                                            attention_mask_func,
-                                            layer_number)
+            return ParallelTransformerLayer(attention_mask_func,
+                                            mlp_activation_func,
+                                            init_method,
+                                            output_layer_init_method,
+                                            layer_number)

         # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i + 1) for i in range(hyperparameters['num_layers'])])
+            [get_layer(i + 1) for i in range(args.num_layers)])

         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)

     def _checkpointed_forward(self, hidden_states, attention_mask):
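The deleted TransformerHyperparameters class existed mainly to give dict-style access with a not-None check; with get_args() the same lookup becomes a plain attribute read, and a missing setting surfaces as an AttributeError instead of the custom exception. A self-contained comparison sketch; the Hparams class below mimics the removed container, and the Namespace stands in for the object get_args() would return:

from argparse import Namespace


class Hparams:
    """Mimics the removed dict-backed TransformerHyperparameters."""
    def __init__(self, **kwargs):
        self.params_dict = dict(kwargs)

    def __getitem__(self, key):
        if key not in self.params_dict:
            raise Exception('could not find {} in transformer '
                            'hyperparameters'.format(key))
        value = self.params_dict[key]
        assert value is not None, \
            'parameter value for {} is not set'.format(key)
        return value


# Old style: dict-style lookup with validation on every access.
hparams = Hparams(hidden_size=1024, layernorm_epsilon=None)
print(hparams['hidden_size'])       # 1024
# hparams['layernorm_epsilon']      # would fail the not-None assertion

# New style: the values live on the parsed argparse namespace, so access
# is a plain attribute read and a typo raises AttributeError directly.
args = Namespace(hidden_size=1024, layernorm_epsilon=1e-5)
print(args.hidden_size, args.layernorm_epsilon)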
pretrain_bert.py

@@ -36,22 +36,9 @@ def model_provider():

     print_rank_0('building BERT model ...')

-    model = BertModel(
-        num_layers=args.num_layers,
-        vocab_size=args.padded_vocab_size,
-        hidden_size=args.hidden_size,
-        num_attention_heads=args.num_attention_heads,
-        embedding_dropout_prob=args.hidden_dropout,
-        attention_dropout_prob=args.attention_dropout,
-        output_dropout_prob=args.hidden_dropout,
-        max_sequence_length=args.max_position_embeddings,
-        checkpoint_activations=args.checkpoint_activations,
-        checkpoint_num_layers=args.checkpoint_num_layers,
-        add_binary_head=True,
-        layernorm_epsilon=args.layernorm_epsilon,
-        num_tokentypes=2,
-        parallel_output=True,
-        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-        attention_softmax_in_fp32=args.attention_softmax_in_fp32)
+    model = BertModel(num_tokentypes=2, add_binary_head=True,
+                      parallel_output=True)

     return model
pretrain_gpt2.py

@@ -37,20 +37,7 @@ def model_provider():

     args = get_args()
     print_rank_0('building GPT2 model ...')

-    model = GPT2Model(num_layers=args.num_layers,
-                      vocab_size=args.padded_vocab_size,
-                      hidden_size=args.hidden_size,
-                      num_attention_heads=args.num_attention_heads,
-                      embedding_dropout_prob=args.hidden_dropout,
-                      attention_dropout_prob=args.attention_dropout,
-                      output_dropout_prob=args.hidden_dropout,
-                      max_sequence_length=args.max_position_embeddings,
-                      checkpoint_activations=args.checkpoint_activations,
-                      checkpoint_num_layers=args.checkpoint_num_layers,
-                      layernorm_epsilon=args.layernorm_epsilon,
-                      parallel_output=True,
-                      apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-                      attention_softmax_in_fp32=args.attention_softmax_in_fp32)
+    model = GPT2Model(num_tokentypes=0, parallel_output=True)

     return model
tasks/glue/finetune.py

@@ -46,17 +46,7 @@ def glue_classification(num_classes, Dataset,

         print_rank_0('building classification model for {} ...'.format(
             args.task))

-        return Classification(
-            num_classes=num_classes,
-            num_layers=args.num_layers,
-            vocab_size=args.padded_vocab_size,
-            hidden_size=args.hidden_size,
-            num_attention_heads=args.num_attention_heads,
-            embedding_dropout_prob=args.hidden_dropout,
-            attention_dropout_prob=args.attention_dropout,
-            output_dropout_prob=args.hidden_dropout,
-            max_sequence_length=args.max_position_embeddings,
-            checkpoint_activations=args.checkpoint_activations)
+        return Classification(num_classes=num_classes, num_tokentypes=2)

     def metrics_func_provider():
tasks/race/finetune.py

@@ -39,20 +39,10 @@ def train_valid_datasets_provider():

 def model_provider():
     """Build the model."""
     args = get_args()

     print_rank_0('building multichoice model for RACE ...')

-    return MultipleChoice(
-        num_layers=args.num_layers,
-        vocab_size=args.padded_vocab_size,
-        hidden_size=args.hidden_size,
-        num_attention_heads=args.num_attention_heads,
-        embedding_dropout_prob=args.hidden_dropout,
-        attention_dropout_prob=args.attention_dropout,
-        output_dropout_prob=args.hidden_dropout,
-        max_sequence_length=args.max_position_embeddings,
-        checkpoint_activations=args.checkpoint_activations)
+    return MultipleChoice(num_tokentypes=2)


 def metrics_func_provider():