ModelZoo / ResNet50_tensorflow

Commit 9cbe7fab, authored Apr 27, 2020 by A. Unique TensorFlower
Parents: 7cc0970b, f7852565

    Merge pull request #8403 from stagedml:bert-pretrain-embedding-table

    PiperOrigin-RevId: 308649588

Showing 3 changed files, with 11 additions and 3 deletions (+11 -3):
official/nlp/bert/bert_models.py                 +1 -0
official/nlp/modeling/models/bert_pretrainer.py  +5 -2
official/nlp/modeling/networks/masked_lm.py      +5 -1

Taken together, the three changes let a caller inject the encoder's embedding
table into BertPretrainer and MaskedLM explicitly, instead of each head always
fetching it from its source network.
official/nlp/bert/bert_models.py

@@ -212,6 +212,7 @@ def pretrain_model(bert_config,
       stddev=bert_config.initializer_range)
   pretrainer_model = models.BertPretrainer(
       network=transformer_encoder,
+      embedding_table=transformer_encoder.get_embedding_table(),
       num_classes=2,  # The next sentence prediction label has two classes.
       num_token_predictions=max_predictions_per_seq,
       initializer=initializer,
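This is the caller-side half of the change: pretrain_model now reads the
encoder's embedding table once, via get_embedding_table(), and hands it to
BertPretrainer explicitly. A minimal, self-contained sketch of that accessor
pattern, using a hypothetical toy encoder rather than the repo's real
TransformerEncoder (only the get_embedding_table() method name mirrors the
real API):

import tensorflow as tf

class ToyEncoder(tf.keras.Model):
  """Hypothetical stand-in for the transformer encoder above."""

  def __init__(self, vocab_size=100, hidden_size=16):
    super().__init__()
    self._embedding = tf.keras.layers.Embedding(vocab_size, hidden_size)
    self._embedding.build(None)  # Create the [vocab, hidden] variable eagerly.

  def get_embedding_table(self):
    # The accessor this commit relies on: expose the raw embedding weights.
    return self._embedding.embeddings

encoder = ToyEncoder()
table = encoder.get_embedding_table()
print(table.shape)  # (100, 16): one row per vocabulary entry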
official/nlp/modeling/models/bert_pretrainer.py

@@ -39,10 +39,11 @@ class BertPretrainer(tf.keras.Model):
   Arguments:
     network: A transformer network. This network should output a sequence output
-      and a classification output. Furthermore, it should expose its embedding
-      table via a "get_embedding_table" method.
+      and a classification output.
     num_classes: Number of classes to predict from the classification network.
     num_token_predictions: Number of tokens to predict from the masked LM.
+    embedding_table: Embedding table of a network. If None,
+      "network.get_embedding_table()" is used.
     activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
     initializer: The initializer (if any) to use in the masked LM and

@@ -55,6 +56,7 @@ class BertPretrainer(tf.keras.Model):
                network,
                num_classes,
                num_token_predictions,
+               embedding_table=None,
                activation=None,
                initializer='glorot_uniform',
                output='logits',

@@ -100,6 +102,7 @@ class BertPretrainer(tf.keras.Model):
         num_predictions=num_token_predictions,
         input_width=sequence_output.shape[-1],
         source_network=network,
+        embedding_table=embedding_table,
         activation=activation,
         initializer=initializer,
         output=output,
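The shape of this change is a backward-compatible API extension: a new keyword
argument with a None default, forwarded verbatim to the inner masked-LM
network, so existing call sites behave exactly as before. A self-contained
sketch of that pass-through with hypothetical toy classes (not the repo's
code):

class ToySourceNetwork:
  """Hypothetical network exposing only the accessor that matters here."""

  def get_embedding_table(self):
    return 'table-from-network'  # Stand-in for the real tf.Variable.

class ToyMaskedLM:
  def __init__(self, source_network, embedding_table=None):
    # Mirrors the fallback in masked_lm.py below: consult the source
    # network only when no table was passed in.
    if embedding_table is None:
      embedding_table = source_network.get_embedding_table()
    self.embedding_table = embedding_table

class ToyPretrainer:
  def __init__(self, network, embedding_table=None):
    # The new keyword defaults to None and is forwarded as-is.
    self.masked_lm = ToyMaskedLM(source_network=network,
                                 embedding_table=embedding_table)

print(ToyPretrainer(ToySourceNetwork()).masked_lm.embedding_table)
# -> table-from-network (old call sites, unchanged behavior)
print(ToyPretrainer(ToySourceNetwork(),
                    embedding_table='shared-table').masked_lm.embedding_table)
# -> shared-table (new call sites inject the encoder's table)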
official/nlp/modeling/networks/masked_lm.py

@@ -37,6 +37,8 @@ class MaskedLM(network.Network):
     num_predictions: The number of predictions to make per sequence.
     source_network: The network with the embedding layer to use for the
       embedding layer.
+    embedding_table: The embedding table of a source network. If None, the
+      `source_network.get_embedding_table()` method is used.
     activation: The activation, if any, for the dense layer in this network.
     initializer: The initializer for the dense layer in this network. Defaults to
       a Glorot uniform initializer.

@@ -48,12 +50,14 @@ class MaskedLM(network.Network):
                input_width,
                num_predictions,
                source_network,
+               embedding_table=None,
                activation=None,
                initializer='glorot_uniform',
                output='logits',
                **kwargs):
-    embedding_table = source_network.get_embedding_table()
+    if embedding_table is None:
+      embedding_table = source_network.get_embedding_table()
     vocab_size, hidden_size = embedding_table.shape
     sequence_data = tf.keras.layers.Input(
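Whichever way the table arrives, MaskedLM unpacks it as
`vocab_size, hidden_size = embedding_table.shape` and reuses it as the output
projection, tying the masked-LM logits to the input embeddings. A minimal
TensorFlow sketch of such a tied projection (toy shapes, not the repo's code):

import tensorflow as tf

vocab_size, hidden_size = 100, 16

# Stand-in for the shared table; in the real model this is the encoder's
# [vocab_size, hidden_size] input embedding variable.
embedding_table = tf.random.normal([vocab_size, hidden_size])

# Hidden vectors for 8 masked positions (after the dense/layer-norm
# transform in the real network).
masked_hidden = tf.random.normal([8, hidden_size])

# Tied output projection: score each vocabulary entry by its dot product
# with the hidden state, reusing the input embedding rows as weights.
logits = tf.matmul(masked_hidden, embedding_table, transpose_b=True)
print(logits.shape)  # (8, 100): one score per vocabulary entry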