Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
840a493a
Commit
840a493a
authored
May 15, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
May 15, 2020
Browse files
Internal change
PiperOrigin-RevId: 311773503
parent
bce4604a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
3 deletions
+8
-3
official/nlp/modeling/models/bert_pretrainer.py
official/nlp/modeling/models/bert_pretrainer.py
+8
-3
No files found.
official/nlp/modeling/models/bert_pretrainer.py
View file @
840a493a
...
@@ -70,12 +70,12 @@ class BertPretrainer(tf.keras.Model):
...
@@ -70,12 +70,12 @@ class BertPretrainer(tf.keras.Model):
'initializer'
:
initializer
,
'initializer'
:
initializer
,
'output'
:
output
,
'output'
:
output
,
}
}
self
.
encoder
=
network
# We want to use the inputs of the passed network as the inputs to this
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a copy of the network inputs for use
# Model. To do this, we need to keep a copy of the network inputs for use
# when we construct the Model object at the end of init. (We keep a copy
# when we construct the Model object at the end of init. (We keep a copy
# because we'll be adding another tensor to the copy later.)
# because we'll be adding another tensor to the copy later.)
network_inputs
=
network
.
inputs
network_inputs
=
self
.
encoder
.
inputs
inputs
=
copy
.
copy
(
network_inputs
)
inputs
=
copy
.
copy
(
network_inputs
)
# Because we have a copy of inputs to create this Model object, we can
# Because we have a copy of inputs to create this Model object, we can
...
@@ -83,8 +83,13 @@ class BertPretrainer(tf.keras.Model):
...
@@ -83,8 +83,13 @@ class BertPretrainer(tf.keras.Model):
# Note that, because of how deferred construction happens, we can't use
# Note that, because of how deferred construction happens, we can't use
# the copy of the list here - by the time the network is invoked, the list
# the copy of the list here - by the time the network is invoked, the list
# object contains the additional input added below.
# object contains the additional input added below.
sequence_output
,
cls_output
=
network
(
network_inputs
)
sequence_output
,
cls_output
=
self
.
encoder
(
network_inputs
)
# The encoder network may get outputs from all layers.
if
isinstance
(
sequence_output
,
list
):
sequence_output
=
sequence_output
[
-
1
]
if
isinstance
(
cls_output
,
list
):
cls_output
=
cls_output
[
-
1
]
sequence_output_length
=
sequence_output
.
shape
.
as_list
()[
1
]
sequence_output_length
=
sequence_output
.
shape
.
as_list
()[
1
]
if
sequence_output_length
<
num_token_predictions
:
if
sequence_output_length
<
num_token_predictions
:
raise
ValueError
(
raise
ValueError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment