OpenDAS / Megatron-LM

Commit bcb320ee, authored Mar 26, 2020 by Neel Kant
Parent: 21a916b1

Add ICT-related parameters to BertModel

Showing 1 changed file with 22 additions and 4 deletions:
megatron/model/bert_model.py (+22, -4)
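For context: ICT here presumably refers to the Inverse Cloze Task (Lee et al., 2019), a retrieval pretraining objective in which a sentence serves as a pseudo-query for the passage it was drawn from. The diff below threads an optional ict_head_size parameter through BertModel: when it is set, a linear ICT head is attached to the pooled [CLS] output, mutually exclusive with the next-sentence-prediction binary head. A minimal sketch of that pattern in plain PyTorch (ToyPooledHead and its arguments are illustrative, not part of this commit, and nn.Linear stands in for Megatron's get_linear_layer):

    import torch
    import torch.nn as nn

    class ToyPooledHead(nn.Module):
        # Sketch of the pattern this commit adds: at most one optional head
        # (NSP binary head or ICT projection) sits on the pooled output.
        def __init__(self, hidden_size, add_binary_head=False, ict_head_size=None):
            super().__init__()
            self.add_binary_head = add_binary_head
            self.add_ict_head = ict_head_size is not None
            assert not (self.add_binary_head and self.add_ict_head)
            if self.add_binary_head:
                self.head = nn.Linear(hidden_size, 2)              # NSP logits
            elif self.add_ict_head:
                self.head = nn.Linear(hidden_size, ict_head_size)  # ICT embedding

        def forward(self, pooled_output):
            if self.add_binary_head or self.add_ict_head:
                return self.head(pooled_output)
            return None

    # e.g. a 128-dimensional retrieval embedding from a 1024-dim pooled state:
    ict = ToyPooledHead(1024, ict_head_size=128)
    print(ict(torch.randn(4, 1024)).shape)  # torch.Size([4, 128])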
--- a/megatron/model/bert_model.py
+++ b/megatron/model/bert_model.py
@@ -74,7 +74,7 @@ class BertLMHead(MegatronModule):
         hidden_size: hidden size
         init_method: init method for weight initialization
         layernorm_epsilon: tolerance for layer norm divisions
-        parallel_output: wether output logits being distributed or not.
+        parallel_output: whether output logits being distributed or not.
     """
     def __init__(self, mpu_vocab_size, hidden_size, init_method,
                  layernorm_epsilon, parallel_output):
@@ -118,6 +118,7 @@ class BertModel(MegatronModule):
                  checkpoint_activations,
                  checkpoint_num_layers=1,
                  add_binary_head=False,
+                 ict_head_size=None,
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
@@ -128,8 +129,13 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         self.add_binary_head = add_binary_head
+        self.ict_head_size = ict_head_size
+        self.add_ict_head = ict_head_size is not None
+        assert not (self.add_binary_head and self.add_ict_head)
         self.parallel_output = parallel_output
         init_method = init_method_normal(init_method_std)
+        add_pooler = self.add_binary_head or self.add_ict_head

         self.language_model, self._language_model_key = get_language_model(
             num_layers=num_layers,
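Two details in this hunk are worth noting: the ICT head and the binary (NSP) head are asserted to be mutually exclusive, and add_pooler is now derived from either head rather than from add_binary_head alone, since both heads consume the pooled output. A standalone illustration of the derived flags, mirroring the lines above (resolve_heads is a hypothetical helper, not part of the commit):

    def resolve_heads(add_binary_head=False, ict_head_size=None):
        # Mirrors the constructor logic in the hunk above.
        add_ict_head = ict_head_size is not None
        assert not (add_binary_head and add_ict_head)  # heads are exclusive
        add_pooler = add_binary_head or add_ict_head   # either head needs pooling
        return add_ict_head, add_pooler

    print(resolve_heads())                      # (False, False): plain MLM BERT
    print(resolve_heads(add_binary_head=True))  # (False, True):  NSP head
    print(resolve_heads(ict_head_size=128))     # (True, True):   ICT head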
@@ -141,7 +147,7 @@ class BertModel(MegatronModule):
             output_dropout_prob=output_dropout_prob,
             max_sequence_length=max_sequence_length,
             num_tokentypes=num_tokentypes,
-            add_pooler=self.add_binary_head,
+            add_pooler=add_pooler,
             attention_mask_func=bert_attention_mask_func,
             checkpoint_activations=checkpoint_activations,
             checkpoint_num_layers=checkpoint_num_layers,
@@ -161,7 +167,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             self.binary_head = get_linear_layer(hidden_size, 2, init_method)
             self._binary_head_key = 'binary_head'
+        elif self.add_ict_head:
+            self.ict_head = get_linear_layer(hidden_size, ict_head_size, init_method)
+            self._ict_head_key = 'ict_head'

     def forward(self, input_ids, attention_mask, tokentype_ids=None):
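The ICT head reuses get_linear_layer, the same helper that builds the binary head. Its definition is not part of this diff; assuming it follows Megatron's usual pattern, it is roughly an nn.Linear whose weight is passed through the supplied init function, as in this sketch (presumed behavior, not code from the commit):

    import torch

    def get_linear_layer_sketch(rows, columns, init_method):
        # Presumed behavior of Megatron's get_linear_layer helper:
        # build a Linear layer, initialize the weight, zero the bias.
        layer = torch.nn.Linear(rows, columns)
        init_method(layer.weight)
        with torch.no_grad():
            layer.bias.zero_()
        return layer

    init = lambda w: torch.nn.init.normal_(w, mean=0.0, std=0.02)
    ict_head = get_linear_layer_sketch(1024, 128, init)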
@@ -170,7 +178,7 @@ class BertModel(MegatronModule):
             attention_mask, next(self.language_model.parameters()).dtype)
         position_ids = bert_position_ids(input_ids)

-        if self.add_binary_head:
+        if self.add_binary_head or self.add_ict_head:
             lm_output, pooled_output = self.language_model(
                 input_ids,
                 position_ids,
@@ -190,6 +198,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
             return lm_logits, binary_logits
+        elif self.add_ict_head:
+            ict_logits = self.ict_head(pooled_output)
+            return lm_logits, ict_logits

         return lm_logits, None
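With the ICT head enabled, forward returns (lm_logits, ict_logits), where ict_logits is a pooled embedding rather than a classification score. How those embeddings are trained is outside this commit; a common recipe for ICT (e.g. in ORQA-style retrieval pretraining) scores each query against every context in the batch and treats the matching pair as the positive class. A hedged sketch of that loss, illustrative only:

    import torch
    import torch.nn.functional as F

    def in_batch_ict_loss(query_embed, context_embed):
        # Sketch of in-batch-negative training for ICT embeddings;
        # not code from this commit.
        scores = query_embed @ context_embed.t()  # [batch, batch] similarities
        labels = torch.arange(scores.size(0), device=scores.device)
        return F.cross_entropy(scores, labels)    # diagonal = positive pairs

    loss = in_batch_ict_loss(torch.randn(8, 128), torch.randn(8, 128))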
@@ -209,6 +220,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             state_dict_[self._binary_head_key] \
                 = self.binary_head.state_dict(destination, prefix, keep_vars)
+        elif self.add_ict_head:
+            state_dict_[self._ict_head_key] \
+                = self.ict_head.state_dict(destination, prefix, keep_vars)

         return state_dict_
@@ -222,3 +236,7 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             self.binary_head.load_state_dict(state_dict[self._binary_head_key],
                                              strict=strict)
+        elif self.add_ict_head:
+            self.ict_head.load_state_dict(state_dict[self._ict_head_key],
+                                          strict=strict)
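The serialization hunks are symmetric: whichever head exists is saved under its own key ('binary_head' or 'ict_head') and restored from that key on load. A toy round-trip showing the same mechanism, with a bare nn.Linear standing in for the ICT head:

    import torch

    head = torch.nn.Linear(16, 4)
    state_dict_ = {'ict_head': head.state_dict()}     # mirrors the save path

    restored = torch.nn.Linear(16, 4)
    restored.load_state_dict(state_dict_['ict_head'], strict=True)
    assert torch.equal(restored.weight, head.weight)  # round-trip preserved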