ModelZoo / ResNet50_tensorflow / Commits

Commit f7852565, authored Apr 16, 2020 by Sergey Mironov

Make BertPretrainer to accept embedding_table explicitly

parent 31da2245
Showing 3 changed files with 13 additions and 3 deletions:

- official/nlp/bert/bert_models.py (+1, −0)
- official/nlp/modeling/models/bert_pretrainer.py (+5, −2)
- official/nlp/modeling/networks/masked_lm.py (+7, −1)
official/nlp/bert/bert_models.py

```diff
@@ -212,6 +212,7 @@ def pretrain_model(bert_config,
       stddev=bert_config.initializer_range)
   pretrainer_model = models.BertPretrainer(
       network=transformer_encoder,
+      embedding_table=transformer_encoder.get_embedding_table(),
       num_classes=2,  # The next sentence prediction label has two classes.
       num_token_predictions=max_predictions_per_seq,
       initializer=initializer,
```
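The call-site change above wires the encoder's embedding table into the pretrainer explicitly, while the constructor keeps the old implicit lookup as a fallback. A minimal sketch of that wiring, using stand-in classes (`TransformerEncoderStub` and `BertPretrainerStub` are hypothetical names, not part of the tensorflow/models API):

```python
class TransformerEncoderStub:
    """Stands in for the transformer encoder; owns the embedding table."""

    def __init__(self, vocab_size, hidden_size):
        # A nested list stands in for the real tf.Variable weight matrix.
        self._table = [[0.0] * hidden_size for _ in range(vocab_size)]

    def get_embedding_table(self):
        return self._table


class BertPretrainerStub:
    """Stands in for models.BertPretrainer after this commit."""

    def __init__(self, network, embedding_table=None, num_classes=2):
        # New behavior: the caller may hand in the table explicitly;
        # the old implicit lookup remains as the fallback.
        if embedding_table is None:
            embedding_table = network.get_embedding_table()
        self.embedding_table = embedding_table
        self.num_classes = num_classes


encoder = TransformerEncoderStub(vocab_size=30522, hidden_size=8)

# After this commit the call site passes the table explicitly:
explicit = BertPretrainerStub(
    network=encoder,
    embedding_table=encoder.get_embedding_table(),
)
# Older call sites that omit the argument still work:
implicit = BertPretrainerStub(network=encoder)

assert explicit.embedding_table is encoder.get_embedding_table()
assert implicit.embedding_table is encoder.get_embedding_table()
```

Both call styles end up with the same table object; the explicit form just makes the dependency visible at the call site.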
official/nlp/modeling/models/bert_pretrainer.py

```diff
@@ -39,14 +39,15 @@ class BertPretrainer(tf.keras.Model):
   Arguments:
     network: A transformer network. This network should output a sequence output
-      and a classification output. Furthermore, it should expose its embedding
-      table via a "get_embedding_table" method.
+      and a classification output.
     num_classes: Number of classes to predict from the classification network.
     num_token_predictions: Number of tokens to predict from the masked LM.
     activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
     initializer: The initializer (if any) to use in the masked LM and
       classification networks. Defaults to a Glorot uniform initializer.
+    embedding_table: Embedding table of a network. If None, the
+      "network.get_embedding_table()" is used.
     output: The output style for this network. Can be either 'logits' or
       'predictions'.
   """
@@ -58,6 +59,7 @@ class BertPretrainer(tf.keras.Model):
                activation=None,
                initializer='glorot_uniform',
                output='logits',
+               embedding_table=None,
                **kwargs):
     self._self_setattr_tracking = False
     self._config = {
@@ -100,6 +102,7 @@ class BertPretrainer(tf.keras.Model):
         num_predictions=num_token_predictions,
         input_width=sequence_output.shape[-1],
         source_network=network,
+        embedding_table=embedding_table,
         activation=activation,
         initializer=initializer,
         output=output,
```
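The reason the masked-LM head wants the encoder's embedding table at all is weight tying: the output projection reuses the input embeddings, so MLM logits are computed as `hidden_states @ embedding_table^T` and the vocabulary size is read off the table's first dimension. A tiny pure-Python sketch of that tied projection (the real MaskedLM does this with `tf.matmul`; `mlm_logits` is a hypothetical helper name):

```python
def mlm_logits(hidden_states, embedding_table):
    """Tied-weight MLM projection.

    hidden_states: [num_predictions][hidden_size]
    embedding_table: [vocab_size][hidden_size]
    returns: [num_predictions][vocab_size], i.e. hidden @ table^T
    """
    vocab_size = len(embedding_table)
    return [
        [sum(h * e for h, e in zip(hidden, embedding_table[v]))
         for v in range(vocab_size)]
        for hidden in hidden_states
    ]


# Vocabulary of 3 tokens, hidden size 2.
table = [[1.0, 0.0],
         [0.0, 1.0],
         [1.0, 1.0]]
hidden = [[2.0, 3.0]]  # one masked position
logits = mlm_logits(hidden, table)
assert logits == [[2.0, 3.0, 5.0]]
```

Because the head only needs the table itself, not the whole network, passing the table explicitly (rather than requiring a `get_embedding_table` method on the network) loosens the coupling between `BertPretrainer` and its encoder.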
official/nlp/modeling/networks/masked_lm.py

```diff
@@ -37,6 +37,8 @@ class MaskedLM(network.Network):
     num_predictions: The number of predictions to make per sequence.
     source_network: The network with the embedding layer to use for the
       embedding layer.
+    embedding_table: The embedding table of a source network. If None, the
+      `source_network.get_embedding_table()` method is used.
     activation: The activation, if any, for the dense layer in this network.
     initializer: The initializer for the dense layer in this network. Defaults to
       a Glorot uniform initializer.
@@ -48,12 +50,16 @@ class MaskedLM(network.Network):
                input_width,
                num_predictions,
                source_network,
+               embedding_table=None,
                activation=None,
                initializer='glorot_uniform',
                output='logits',
                **kwargs):
-    embedding_table = source_network.get_embedding_table()
+    if embedding_table is None:
+      embedding_table = source_network.get_embedding_table()
     vocab_size, hidden_size = embedding_table.shape
     sequence_data = tf.keras.layers.Input(
```
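The `if embedding_table is None: embedding_table = source_network.get_embedding_table()` line in `MaskedLM.__init__` is a backward-compatible dependency-injection idiom: callers can inject the table, and omitting it falls back to the old lookup. A minimal sketch of that pattern under the same shape convention (`MaskedLMStub` and `SourceStub` are hypothetical stand-ins, not the real tf-models classes):

```python
class SourceStub:
    """Stands in for a source network exposing get_embedding_table()."""

    def __init__(self, table):
        self._table = table

    def get_embedding_table(self):
        return self._table


class MaskedLMStub:
    """Stands in for MaskedLM after this commit."""

    def __init__(self, source_network, embedding_table=None):
        # Injected table wins; otherwise fall back to the old lookup.
        if embedding_table is None:
            embedding_table = source_network.get_embedding_table()
        # Mirrors: vocab_size, hidden_size = embedding_table.shape
        self.vocab_size = len(embedding_table)
        self.hidden_size = len(embedding_table[0])


source = SourceStub([[0.0] * 4 for _ in range(10)])

old_style = MaskedLMStub(source_network=source)  # fallback path
new_style = MaskedLMStub(source_network=source,
                         embedding_table=[[0.0] * 4 for _ in range(12)])

assert (old_style.vocab_size, old_style.hidden_size) == (10, 4)
assert (new_style.vocab_size, new_style.hidden_size) == (12, 4)
```

Note the keyword argument is inserted before `activation` in the real signature but given a default, so existing keyword-only call sites keep working unchanged.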