ModelZoo / ResNet50_tensorflow · Commit 52b16a1a
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "ce8ce8256728d4211e50e0db72195073ad37127c"
Commit 52b16a1a, authored Sep 21, 2020 by Chen Chen, committed by A. Unique TensorFlower on Sep 21, 2020

Internal change

PiperOrigin-RevId: 332806032
Parent: 71a2fc91
Showing 2 changed files with 49 additions and 11 deletions:
official/nlp/modeling/models/bert_pretrainer.py       +20 -8
official/nlp/modeling/models/bert_pretrainer_test.py  +29 -3
official/nlp/modeling/models/bert_pretrainer.py
```diff
@@ -161,8 +161,9 @@ class BertPretrainerV2(tf.keras.Model):
     name: The name of the model.
   Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a
     dictionary.
-  Outputs: A dictionary of `lm_output` and classification head outputs keyed by
-    head names.
+  Outputs: A dictionary of `lm_output`, classification head outputs keyed by
+    head names, and also outputs from `encoder_network`, keyed by
+    `pooled_output`, `sequence_output` and `encoder_outputs` (if any).
   """

   def __init__(
```
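The docstring change documents the new output contract: callers now get the encoder's own outputs back alongside the MLM logits. As a hedged usage sketch (the constructor defaults, import paths, and toy input shapes below are assumptions that mirror the updated test, not documented API):

```python
import numpy as np
from official.nlp.modeling import models, networks

# Encoder configured the same way as the updated test (dict_outputs=True).
encoder = networks.BertEncoder(
    vocab_size=100,
    num_layers=2,
    hidden_size=48,
    max_sequence_length=512,
    dict_outputs=True)
# Assumption: defaults suffice for everything except the encoder.
pretrainer = models.BertPretrainerV2(encoder_network=encoder)

batch, seq_len, num_preds = 2, 16, 4
word_ids = np.random.randint(0, 100, (batch, seq_len), dtype=np.int32)
mask = np.ones((batch, seq_len), dtype=np.int32)
type_ids = np.zeros((batch, seq_len), dtype=np.int32)
lm_positions = np.random.randint(0, seq_len, (batch, num_preds),
                                 dtype=np.int32)

# Input ordering follows the test: encoder inputs plus masked LM positions.
outputs = pretrainer([word_ids, mask, type_ids, lm_positions])
# After this commit the dict also carries the encoder outputs:
# ['encoder_outputs', 'lm_output', 'pooled_output', 'sequence_output']
print(sorted(outputs.keys()))
```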
```diff
@@ -180,21 +181,32 @@ class BertPretrainerV2(tf.keras.Model):
         'classification_heads': classification_heads,
         'name': name,
     }
     self.encoder_network = encoder_network
     inputs = copy.copy(self.encoder_network.inputs)
-    outputs = self.encoder_network(inputs)
-    if isinstance(outputs, list):
-      sequence_output = outputs[0]
-    else:
-      sequence_output = outputs['sequence_output']
+    outputs = dict()
+    encoder_network_outputs = self.encoder_network(inputs)
+    if isinstance(encoder_network_outputs, list):
+      outputs['pooled_output'] = encoder_network_outputs[1]
+      # When `encoder_network` was instantiated with return_all_encoder_outputs
+      # set to True, `encoder_network_outputs[0]` is a list containing
+      # all transformer layers' output.
+      if isinstance(encoder_network_outputs[0], list):
+        outputs['encoder_outputs'] = encoder_network_outputs[0]
+        outputs['sequence_output'] = encoder_network_outputs[0][-1]
+      else:
+        outputs['sequence_output'] = encoder_network_outputs[0]
+    elif isinstance(encoder_network_outputs, dict):
+      outputs = encoder_network_outputs
+    else:
+      raise ValueError('encoder_network\'s output should be either a list '
+                       'or a dict, but got %s' % encoder_network_outputs)
+    sequence_output = outputs['sequence_output']
     self.classification_heads = classification_heads or []
     if len(set([cls.name for cls in self.classification_heads])) != len(
         self.classification_heads):
       raise ValueError('Classification heads should have unique names.')

-    outputs = dict()
     self.masked_lm = layers.MaskedLM(
         embedding_table=self.encoder_network.get_embedding_table(),
         activation=mlm_activation,
```
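To make the branching above easier to follow in isolation, here is a standalone paraphrase (plain Python, not library code) of the normalization this hunk introduces; the toy string inputs at the bottom stand in for tensors:

```python
def normalize_encoder_outputs(encoder_network_outputs):
  """Maps list- or dict-shaped encoder outputs onto one dict schema."""
  outputs = {}
  if isinstance(encoder_network_outputs, list):
    outputs['pooled_output'] = encoder_network_outputs[1]
    if isinstance(encoder_network_outputs[0], list):
      # Encoder built with return_all_encoder_outputs=True: element 0 holds
      # every transformer layer's activations; the last one is the final
      # sequence output.
      outputs['encoder_outputs'] = encoder_network_outputs[0]
      outputs['sequence_output'] = encoder_network_outputs[0][-1]
    else:
      outputs['sequence_output'] = encoder_network_outputs[0]
  elif isinstance(encoder_network_outputs, dict):
    outputs = encoder_network_outputs
  else:
    raise ValueError("encoder_network's output should be either a list "
                     'or a dict, but got %s' % encoder_network_outputs)
  return outputs

# Both list-shaped cases collapse to the same dict keys:
assert normalize_encoder_outputs(['seq', 'pool'])['sequence_output'] == 'seq'
assert normalize_encoder_outputs(
    [['layer1', 'layer2'], 'pool'])['sequence_output'] == 'layer2'
```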
official/nlp/modeling/models/bert_pretrainer_test.py
```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for BERT pretrainer model."""
+import itertools
 from absl.testing import parameterized
 import tensorflow as tf
```
```diff
@@ -108,16 +109,23 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     self.assertAllEqual(bert_trainer_model.get_config(),
                         new_bert_trainer_model.get_config())

-  @parameterized.parameters(True, False)
-  def test_bert_pretrainerv2(self, dict_outputs):
+  @parameterized.parameters(itertools.product(
+      (False, True),
+      (False, True),
+  ))
+  def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
+    hidden_size = 48
+    num_layers = 2
     test_network = networks.BertEncoder(
         vocab_size=vocab_size,
-        num_layers=2,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
         max_sequence_length=sequence_length,
+        return_all_encoder_outputs=return_all_encoder_outputs,
         dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
```
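Worth noting how the new decorator expands: `itertools.product` over the two boolean axes generates all four `(dict_outputs, return_all_encoder_outputs)` combinations, so the single test body now exercises every encoder output shape. A quick self-contained illustration:

```python
import itertools

# The same expression used in the decorator above:
cases = list(itertools.product((False, True), (False, True)))
print(cases)
# [(False, False), (False, True), (True, False), (True, True)]
```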
```diff
@@ -133,10 +141,28 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Invoke the trainer model on the inputs. This causes the layer to be built.
     outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    has_encoder_outputs = dict_outputs or return_all_encoder_outputs
+    if has_encoder_outputs:
+      self.assertSameElements(
+          outputs.keys(),
+          ['sequence_output', 'pooled_output', 'lm_output', 'encoder_outputs'])
+      self.assertLen(outputs['encoder_outputs'], num_layers)
+    else:
+      self.assertSameElements(
+          outputs.keys(), ['sequence_output', 'pooled_output', 'lm_output'])
+
     # Validate that the outputs are of the expected shape.
     expected_lm_shape = [None, num_token_predictions, vocab_size]
     self.assertAllEqual(expected_lm_shape, outputs['lm_output'].shape.as_list())
+    expected_sequence_output_shape = [None, sequence_length, hidden_size]
+    self.assertAllEqual(expected_sequence_output_shape,
+                        outputs['sequence_output'].shape.as_list())
+    expected_pooled_output_shape = [None, hidden_size]
+    self.assertAllEqual(expected_pooled_output_shape,
+                        outputs['pooled_output'].shape.as_list())

   def test_v2_serialize_deserialize(self):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
```
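The new assertions gate on `has_encoder_outputs = dict_outputs or return_all_encoder_outputs`. A small sketch (not test code) spelling out the expected key set for each of the four parameter combinations:

```python
# Expected output keys per parameter combination, per the test's logic:
# `encoder_outputs` is present whenever either flag requests per-layer outputs.
BASE_KEYS = {'sequence_output', 'pooled_output', 'lm_output'}

for dict_outputs in (False, True):
  for return_all_encoder_outputs in (False, True):
    expected = set(BASE_KEYS)
    if dict_outputs or return_all_encoder_outputs:
      expected.add('encoder_outputs')
    print(dict_outputs, return_all_encoder_outputs, '->', sorted(expected))
```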