OpenDAS / Megatron-LM

Commit 4ae54b55, authored Jan 12, 2021 by Vijay Korthikanti

    Addressing more review comments

Parent: d836d498
Changes: 8 changed files with 44 additions and 29 deletions (+44, -29).

  megatron/arguments.py              +3   -2
  megatron/model/bert_model.py       +2   -0
  megatron/model/classification.py   +2   -0
  megatron/model/gpt2_model.py       +1   -1
  megatron/model/language_model.py   +32  -25
  megatron/model/multiple_choice.py  +2   -0
  megatron/model/realm_model.py      +2   -0
  megatron/model/transformer.py      +0   -1
megatron/arguments.py

@@ -539,9 +539,10 @@ def _add_data_args(parser):
     group.add_argument('--merge-file', type=str, default=None,
                        help='Path to the BPE merge file.')
     group.add_argument('--seq-length', type=int, default=None,
-                       help="Maximum sequence length to process.")
+                       help='Maximum sequence length to process.')
     group.add_argument('--encoder-seq-length', type=int, default=None,
-                       help="Maximum encoder sequence length to process.")
+                       help='Maximum encoder sequence length to process.'
+                       'This should be exclusive of --seq-length')
     group.add_argument('--decoder-seq-length', type=int, default=None,
                        help="Maximum decoder sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
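Note: the reworded help text marks --encoder-seq-length (and its --decoder-seq-length counterpart) as an alternative to --seq-length rather than an addition to it. The following standalone argparse sketch mirrors the arguments touched by this hunk; the group title, the --mask-prob help wording, and the bare ArgumentParser are assumptions made so the snippet runs outside Megatron:

# Standalone sketch of the data-argument group touched by the hunk above.
# Assumption: the real _add_data_args attaches these options to Megatron's
# shared parser; a bare ArgumentParser stands in here.
import argparse

parser = argparse.ArgumentParser(description='sketch of _add_data_args')
group = parser.add_argument_group(title='data and dataloader')
group.add_argument('--merge-file', type=str, default=None,
                   help='Path to the BPE merge file.')
group.add_argument('--seq-length', type=int, default=None,
                   help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
                   help='Maximum encoder sequence length to process. '
                        'This should be exclusive of --seq-length')
group.add_argument('--decoder-seq-length', type=int, default=None,
                   help='Maximum decoder sequence length to process.')
group.add_argument('--mask-prob', type=float, default=0.15,
                   help='Probability of masking a token (assumed wording).')

# A GPT-style run keeps --seq-length; a T5-style run passes the
# encoder/decoder pair instead of --seq-length.
args = parser.parse_args(['--encoder-seq-length', '512',
                          '--decoder-seq-length', '128'])
print(args.encoder_seq_length, args.decoder_seq_length)  # 512 128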
megatron/model/bert_model.py

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model import import_layernorm

@@ -147,6 +148,7 @@ class BertModelBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
megatron/model/classification.py

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer

@@ -40,6 +41,7 @@ class ClassificationBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
megatron/model/gpt2_model.py

@@ -76,7 +76,7 @@ class GPT2ModelBase(MegatronModule):
             attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            self_attn_mask_type=AttnMaskType.causal,
+            encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
megatron/model/language_model.py

@@ -44,9 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       add_decoder=False, init_method=None,
-                       scaled_init_method=None,
-                       self_attn_mask_type=AttnMaskType.padding):
+                       encoder_attn_mask_type, init_method=None,
+                       scaled_init_method=None, add_decoder=False,
+                       decoder_attn_mask_type=AttnMaskType.causal):
     """Build language model and return along with the key to save."""
     args = get_args()

@@ -58,14 +58,15 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                                                   args.num_layers)

     # Language model.
-    args = [attention_mask_func, init_method, scaled_init_method]
+    args = [attention_mask_func, init_method, scaled_init_method,
+            encoder_attn_mask_type]
     kwargs = {}
     cls = None
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModel
         kwargs['num_tokentypes'] = num_tokentypes
-        kwargs['self_attn_mask_type'] = self_attn_mask_type
         kwargs['add_decoder'] = add_decoder
+        kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
         kwargs['add_pooler'] = add_pooler
     elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModelFirstStage

@@ -192,6 +193,8 @@ class Embedding(MegatronModule):
         if tokentype_ids is not None:
             assert self.tokentype_embeddings is not None
             embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
+        else:
+            assert self.tokentype_embeddings is None

         # Dropout.
         embeddings = self.embedding_dropout(embeddings)

@@ -284,9 +287,10 @@ class TransformerLanguageModelBase(MegatronModule):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding,
                  add_decoder=False,
+                 decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False):
         super(TransformerLanguageModelBase, self).__init__()
         args = get_args()

@@ -294,8 +298,9 @@ class TransformerLanguageModelBase(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
-        self.self_attn_mask_type = self_attn_mask_type
+        self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
+        self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler

         # Embeddings.

@@ -313,7 +318,7 @@ class TransformerLanguageModelBase(MegatronModule):
             attention_mask_func,
             self.init_method,
             output_layer_init_method,
-            self_attn_mask_type=self_attn_mask_type)
+            self_attn_mask_type=self.encoder_attn_mask_type)
         self._encoder_key = 'encoder'

         # Decoder

@@ -325,7 +330,7 @@ class TransformerLanguageModelBase(MegatronModule):
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
-                self_attn_mask_type=AttnMaskType.causal)
+                self_attn_mask_type=self.decoder_attn_mask_type)
             self._decoder_key = 'decoder'

         if mpu.is_pipeline_last_stage():

@@ -334,7 +339,7 @@ class TransformerLanguageModelBase(MegatronModule):
             self.pooler = Pooler(self.hidden_size, self.init_method)
             self._pooler_key = 'pooler'

-    def forward(self, enc_language_model_input, enc_attention_mask,
+    def forward(self, enc_language_model_input, enc_attn_mask,
                 dec_language_model_input=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                 get_key_value=False, pooling_sequence_index=0,

@@ -352,7 +357,7 @@ class TransformerLanguageModelBase(MegatronModule):
         # encoder.
         if enc_hidden_states is None:
             encoder_output = self.encoder(encoder_input,
-                                          enc_attention_mask,
+                                          enc_attn_mask,
                                           layer_past=layer_past,
                                           get_key_value=get_key_value)
         else:

@@ -438,8 +443,8 @@ class TransformerLanguageModelBase(MegatronModule):
         # for backward compatibility.
         state_dict_ = {}
         for key in state_dict.keys():
-            if 'encoder.' in key:
-                state_dict_[key.split('encoder.')[1]] = state_dict[key]
+            if 'transformer.' in key:
+                state_dict_[key.split('transformer.')[1]] = state_dict[key]

         # for backward compatibility.
         state_dict_self_attention = {}

@@ -477,27 +482,29 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding,
+                 decoder_attn_mask_type=AttnMaskType.causal,
                  add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             num_tokentypes=num_tokentypes,
-            self_attn_mask_type=self_attn_mask_type,
             add_decoder=add_decoder,
+            decoder_attn_mask_type=decoder_attn_mask_type,
             add_pooler=add_pooler)

-    def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                 get_key_value=False, pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModel, self).forward(
             (enc_input_ids, enc_position_ids),
-            enc_attention_mask,
+            enc_attn_mask,
             dec_language_model_input=(dec_input_ids, dec_position_ids),
             dec_attn_mask=dec_attn_mask,
             enc_dec_attn_mask=enc_dec_attn_mask,

@@ -519,14 +526,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 encoder_attn_mask_type,
+                 num_tokentypes=0):
         super(TransformerLanguageModelFirstStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            num_tokentypes=num_tokentypes,
-            self_attn_mask_type=self_attn_mask_type)
+            encoder_attn_mask_type,
+            num_tokentypes=num_tokentypes)

     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False):

@@ -548,12 +555,12 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 encoder_attn_mask_type):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            self_attn_mask_type=self_attn_mask_type)
+            encoder_attn_mask_type)

     def forward(self, hidden_states, attention_mask,
                 layer_past=None, get_key_value=False):

@@ -574,13 +581,13 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 self_attn_mask_type=AttnMaskType.padding,
+                 encoder_attn_mask_type,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            self_attn_mask_type=AttnMaskType.padding,
+            encoder_attn_mask_type,
             add_pooler=add_pooler)

     def forward(self, hidden_states, attention_mask,
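Note: across these hunks the single self_attn_mask_type knob is replaced by an explicit per-stack pair: encoder_attn_mask_type is now a required positional argument, and the optional decoder takes decoder_attn_mask_type (causal by default). The sketch below is a minimal, self-contained illustration of the new calling convention, not Megatron's real implementation: AttnMaskType is stubbed with a local enum and the function body merely echoes the resolved configuration instead of dispatching to the TransformerLanguageModel* classes.

# Self-contained sketch of the reworked get_language_model signature.
import enum


class AttnMaskType(enum.Enum):  # stub standing in for megatron.model.enums
    padding = 1
    causal = 2


def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                       encoder_attn_mask_type,
                       init_method=None, scaled_init_method=None,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal):
    # The real code picks TransformerLanguageModel / *FirstStage /
    # *IntermediateStage / *LastStage from mpu pipeline-stage checks and
    # threads these values through args/kwargs as shown in the hunks above;
    # here we just return the resolved configuration.
    return {'encoder_attn_mask_type': encoder_attn_mask_type,
            'add_decoder': add_decoder,
            'decoder_attn_mask_type': decoder_attn_mask_type,
            'add_pooler': add_pooler,
            'num_tokentypes': num_tokentypes}


# BERT-style callers (bert_model.py, classification.py, multiple_choice.py,
# realm_model.py) pass a padding encoder mask; GPT-2, a decoder-only model
# built on the internal "encoder" stack, passes a causal one.
print(get_language_model(None, 2, True, AttnMaskType.padding))
print(get_language_model(None, 0, False, AttnMaskType.causal))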
megatron/model/multiple_choice.py

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer

@@ -39,6 +40,7 @@ class MultipleChoiceBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
megatron/model/realm_model.py

@@ -6,6 +6,7 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
 from megatron.model import BertModel
 from .module import MegatronModule
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.language_model import get_language_model

@@ -159,6 +160,7 @@ class IREncoderBertModel(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
megatron/model/transformer.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Transformer."""
-import enum
 import math

 import torch
 import torch.nn.functional as F
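Note: import enum is no longer needed here because the attention-mask and layer-type enums are imported from megatron.model.enums (see the new imports in bert_model.py, classification.py, multiple_choice.py, and realm_model.py). Below is a hedged sketch of what that module presumably defines, with member names inferred from the values referenced in this commit (AttnMaskType.padding, AttnMaskType.causal, LayerType.decoder); the exact definitions and member values are assumptions.

# megatron/model/enums.py (assumed contents, inferred from usages in this diff)
import enum


class LayerType(enum.Enum):
    encoder = 1
    decoder = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2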