OpenDAS / Megatron-LM

Commit 4ae54b55
Authored Jan 12, 2021 by Vijay Korthikanti
Addressing more review comments
Parent: d836d498

Showing 8 changed files with 44 additions and 29 deletions (+44 -29).
megatron/arguments.py              +3   -2
megatron/model/bert_model.py       +2   -0
megatron/model/classification.py   +2   -0
megatron/model/gpt2_model.py       +1   -1
megatron/model/language_model.py   +32  -25
megatron/model/multiple_choice.py  +2   -0
megatron/model/realm_model.py      +2   -0
megatron/model/transformer.py      +0   -1
megatron/arguments.py

@@ -539,9 +539,10 @@ def _add_data_args(parser):
     group.add_argument('--merge-file', type=str, default=None,
                        help='Path to the BPE merge file.')
     group.add_argument('--seq-length', type=int, default=None,
-                       help="Maximum sequence length to process.")
+                       help='Maximum sequence length to process.')
     group.add_argument('--encoder-seq-length', type=int, default=None,
-                       help="Maximum encoder sequence length to process.")
+                       help='Maximum encoder sequence length to process.'
+                       'This should be exclusive of --seq-length')
     group.add_argument('--decoder-seq-length', type=int, default=None,
                        help="Maximum decoder sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
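For context, a minimal stand-alone sketch of how the data arguments touched above behave once parsed. The parser scaffolding and the group title here are hypothetical; only the flag names, types, and help strings come from the hunk above.

import argparse

# Hypothetical stand-alone parser mirroring the flags changed in this hunk.
parser = argparse.ArgumentParser(description='data-args sketch')
group = parser.add_argument_group(title='data')
group.add_argument('--seq-length', type=int, default=None,
                   help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
                   help='Maximum encoder sequence length to process.'
                        'This should be exclusive of --seq-length')
group.add_argument('--decoder-seq-length', type=int, default=None,
                   help="Maximum decoder sequence length to process.")

# Encoder-decoder style run: pass encoder/decoder lengths instead of --seq-length.
args = parser.parse_args(['--encoder-seq-length', '512',
                          '--decoder-seq-length', '128'])
assert args.seq_length is None
print(args.encoder_seq_length, args.decoder_seq_length)  # 512 128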
megatron/model/bert_model.py

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model import import_layernorm

@@ -147,6 +148,7 @@ class BertModelBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
megatron/model/classification.py

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer

@@ -40,6 +41,7 @@ class ClassificationBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(
                 args.init_method_std, args.num_layers))
megatron/model/gpt2_model.py

@@ -76,7 +76,7 @@ class GPT2ModelBase(MegatronModule):
             attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            self_attn_mask_type=AttnMaskType.causal,
+            encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
megatron/model/language_model.py

@@ -44,9 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       add_decoder=False,
-                       init_method=None, scaled_init_method=None,
-                       self_attn_mask_type=AttnMaskType.padding):
+                       encoder_attn_mask_type,
+                       init_method=None, scaled_init_method=None,
+                       add_decoder=False, decoder_attn_mask_type=AttnMaskType.causal):
     """Build language model and return along with the key to save."""
     args = get_args()

@@ -58,14 +58,15 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
             args.num_layers)

     # Language model.
-    args = [attention_mask_func, init_method, scaled_init_method]
+    args = [attention_mask_func, init_method, scaled_init_method,
+            encoder_attn_mask_type]
     kwargs = {}
     cls = None
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModel
         kwargs['num_tokentypes'] = num_tokentypes
-        kwargs['self_attn_mask_type'] = self_attn_mask_type
         kwargs['add_decoder'] = add_decoder
+        kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
         kwargs['add_pooler'] = add_pooler
     elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModelFirstStage

@@ -192,6 +193,8 @@ class Embedding(MegatronModule):
         if tokentype_ids is not None:
             assert self.tokentype_embeddings is not None
             embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
+        else:
+            assert self.tokentype_embeddings is None

         # Dropout.
         embeddings = self.embedding_dropout(embeddings)

@@ -284,9 +287,10 @@ class TransformerLanguageModelBase(MegatronModule):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding,
                  add_decoder=False,
+                 decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False):
         super(TransformerLanguageModelBase, self).__init__()
         args = get_args()

@@ -294,8 +298,9 @@ class TransformerLanguageModelBase(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
-        self.self_attn_mask_type = self_attn_mask_type
+        self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
+        self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler

         # Embeddings.

@@ -313,7 +318,7 @@ class TransformerLanguageModelBase(MegatronModule):
             attention_mask_func,
             self.init_method,
             output_layer_init_method,
-            self_attn_mask_type=self_attn_mask_type)
+            self_attn_mask_type=self.encoder_attn_mask_type)
         self._encoder_key = 'encoder'

         # Decoder

@@ -325,7 +330,7 @@ class TransformerLanguageModelBase(MegatronModule):
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
-                self_attn_mask_type=AttnMaskType.causal)
+                self_attn_mask_type=self.decoder_attn_mask_type)
             self._decoder_key = 'decoder'

         if mpu.is_pipeline_last_stage():

@@ -334,7 +339,7 @@ class TransformerLanguageModelBase(MegatronModule):
                 self.pooler = Pooler(self.hidden_size, self.init_method)
                 self._pooler_key = 'pooler'

-    def forward(self, enc_language_model_input, enc_attention_mask,
+    def forward(self, enc_language_model_input, enc_attn_mask,
                 dec_language_model_input=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
                 layer_past=None, get_key_value=False,
                 pooling_sequence_index=0,

@@ -352,7 +357,7 @@ class TransformerLanguageModelBase(MegatronModule):
         # encoder.
         if enc_hidden_states is None:
             encoder_output = self.encoder(encoder_input,
-                                          enc_attention_mask,
+                                          enc_attn_mask,
                                           layer_past=layer_past,
                                           get_key_value=get_key_value)
         else:

@@ -438,8 +443,8 @@ class TransformerLanguageModelBase(MegatronModule):
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
-               if 'encoder.' in key:
-                   state_dict_[key.split('encoder.')[1]] = state_dict[key]
+               if 'transformer.' in key:
+                   state_dict_[key.split('transformer.')[1]] = state_dict[key]

        # for backward compatibility.
        state_dict_self_attention = {}

@@ -477,27 +482,29 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding,
+                 decoder_attn_mask_type=AttnMaskType.causal,
                  add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             num_tokentypes=num_tokentypes,
-            self_attn_mask_type=self_attn_mask_type,
             add_decoder=add_decoder,
+            decoder_attn_mask_type=decoder_attn_mask_type,
             add_pooler=add_pooler)

-    def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
                 layer_past=None, get_key_value=False,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModel, self).forward(
             (enc_input_ids, enc_position_ids),
-            enc_attention_mask,
+            enc_attn_mask,
             dec_language_model_input=(dec_input_ids, dec_position_ids),
             dec_attn_mask=dec_attn_mask,
             enc_dec_attn_mask=enc_dec_attn_mask,

@@ -519,14 +526,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 encoder_attn_mask_type,
+                 num_tokentypes=0):
         super(TransformerLanguageModelFirstStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            num_tokentypes=num_tokentypes,
-            self_attn_mask_type=self_attn_mask_type)
+            encoder_attn_mask_type,
+            num_tokentypes=num_tokentypes)

     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False):

@@ -548,12 +555,12 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 encoder_attn_mask_type):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            self_attn_mask_type=self_attn_mask_type)
+            encoder_attn_mask_type)

     def forward(self, hidden_states, attention_mask,
                 layer_past=None, get_key_value=False):

@@ -574,13 +581,13 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 self_attn_mask_type=AttnMaskType.padding,
+                 encoder_attn_mask_type,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            self_attn_mask_type=AttnMaskType.padding,
+            encoder_attn_mask_type,
             add_pooler=add_pooler)

     def forward(self, hidden_states, attention_mask,
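Taken together, these language_model.py changes make the encoder attention-mask type an explicit argument of get_language_model and route an optional decoder mask type through to the pipeline-stage classes, instead of the old self_attn_mask_type parameter. A minimal sketch of the resulting call sites, assuming the Megatron-LM package at this commit is importable; the wrapper function and its flag below are illustrative and not part of the commit:

from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model

def build_lm(attention_mask_func, num_tokentypes, decoder_only,
             init_method=None, scaled_init_method=None):
    """Illustrative wrapper; get_language_model returns the model plus its
    checkpoint key (see the docstring in the diff above)."""
    if decoder_only:
        # GPT-2-style caller: causal mask over the single stack.
        return get_language_model(
            attention_mask_func, num_tokentypes, add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method, scaled_init_method=scaled_init_method)
    # BERT/classification-style caller: padding mask for the encoder; a
    # decoder, if requested, defaults to decoder_attn_mask_type=AttnMaskType.causal.
    return get_language_model(
        attention_mask_func, num_tokentypes, add_pooler=True,
        encoder_attn_mask_type=AttnMaskType.padding,
        init_method=init_method, scaled_init_method=scaled_init_method)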
megatron/model/multiple_choice.py

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer

@@ -39,6 +40,7 @@ class MultipleChoiceBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(
                 args.init_method_std, args.num_layers))
megatron/model/realm_model.py

@@ -6,6 +6,7 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
 from megatron.model import BertModel
 from .module import MegatronModule
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.language_model import get_language_model

@@ -159,6 +160,7 @@ class IREncoderBertModel(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
megatron/model/transformer.py

@@ -14,7 +14,6 @@
 # limitations under the License.

 """Transformer."""
-import enum
 import math

 import torch
 import torch.nn.functional as F
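The dropped import here lines up with the new megatron.model.enums imports added in the other files: the attention-mask and layer-type enums are consumed from that module, so transformer.py no longer needs the standard-library enum module. A small sketch of the pattern, assuming megatron.model.enums exposes the AttnMaskType and LayerType members used in the diffs above:

from megatron.model.enums import AttnMaskType, LayerType

# Mask types selected the way the updated callers do.
encoder_mask = AttnMaskType.padding   # BERT-style encoders
decoder_mask = AttnMaskType.causal    # GPT-2 stacks and decoders

# LayerType.decoder is what language_model.py passes when it builds
# the decoder transformer (see the -325,7 hunk above).
print(encoder_mask, decoder_mask, LayerType.decoder)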