Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
026a7880
Commit
026a7880
authored
Mar 24, 2021
by
Jeremiah Liu
Committed by
A. Unique TensorFlower
Mar 24, 2021
Browse files
Adds unit tests for using `cls_head` in `BertClassifier`.
PiperOrigin-RevId: 364952108
parent
08273bc2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
9 deletions
+21
-9
official/nlp/modeling/models/bert_classifier.py
official/nlp/modeling/models/bert_classifier.py
+5
-4
official/nlp/modeling/models/bert_classifier_test.py
official/nlp/modeling/models/bert_classifier_test.py
+16
-5
No files found.
official/nlp/modeling/models/bert_classifier.py
View file @
026a7880
...
@@ -45,8 +45,8 @@ class BertClassifier(tf.keras.Model):
...
@@ -45,8 +45,8 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head.
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder.
encoder.
cls_head: (Optional) The layer instance to use for the classifier head
cls_head: (Optional) The layer instance to use for the classifier head
.
.
It should take in the output from network and produce the final logits.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
'use_encoder_pooler') will be ignored.
"""
"""
...
@@ -62,7 +62,6 @@ class BertClassifier(tf.keras.Model):
...
@@ -62,7 +62,6 @@ class BertClassifier(tf.keras.Model):
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
initializer
=
initializer
self
.
initializer
=
initializer
self
.
use_encoder_pooler
=
use_encoder_pooler
self
.
use_encoder_pooler
=
use_encoder_pooler
self
.
cls_head
=
cls_head
# We want to use the inputs of the passed network as the inputs to this
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
# Model. To do this, we need to keep a handle to the network inputs for use
...
@@ -107,6 +106,8 @@ class BertClassifier(tf.keras.Model):
...
@@ -107,6 +106,8 @@ class BertClassifier(tf.keras.Model):
super
(
BertClassifier
,
self
).
__init__
(
super
(
BertClassifier
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
predictions
,
**
kwargs
)
inputs
=
inputs
,
outputs
=
predictions
,
**
kwargs
)
self
.
_network
=
network
self
.
_network
=
network
self
.
_cls_head
=
cls_head
config_dict
=
self
.
_make_config_dict
()
config_dict
=
self
.
_make_config_dict
()
# We are storing the config dict as a namedtuple here to ensure checkpoint
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# compatibility with an earlier version of this model which did not track
...
@@ -138,5 +139,5 @@ class BertClassifier(tf.keras.Model):
...
@@ -138,5 +139,5 @@ class BertClassifier(tf.keras.Model):
'num_classes'
:
self
.
num_classes
,
'num_classes'
:
self
.
num_classes
,
'initializer'
:
self
.
initializer
,
'initializer'
:
self
.
initializer
,
'use_encoder_pooler'
:
self
.
use_encoder_pooler
,
'use_encoder_pooler'
:
self
.
use_encoder_pooler
,
'cls_head'
:
self
.
cls_head
,
'cls_head'
:
self
.
_
cls_head
,
}
}
official/nlp/modeling/models/bert_classifier_test.py
View file @
026a7880
...
@@ -18,6 +18,7 @@ from absl.testing import parameterized
...
@@ -18,6 +18,7 @@ from absl.testing import parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
networks
from
official.nlp.modeling
import
networks
from
official.nlp.modeling.models
import
bert_classifier
from
official.nlp.modeling.models
import
bert_classifier
...
@@ -53,16 +54,22 @@ class BertClassifierTest(keras_parameterized.TestCase):
...
@@ -53,16 +54,22 @@ class BertClassifierTest(keras_parameterized.TestCase):
expected_classification_shape
=
[
None
,
num_classes
]
expected_classification_shape
=
[
None
,
num_classes
]
self
.
assertAllEqual
(
expected_classification_shape
,
cls_outs
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_classification_shape
,
cls_outs
.
shape
.
as_list
())
@
parameterized
.
parameters
(
1
,
2
)
@
parameterized
.
named_parameters
(
def
test_bert_trainer_tensor_call
(
self
,
num_classes
):
(
'single_cls'
,
1
,
False
),
(
'2_cls'
,
2
,
False
),
(
'single_cls_custom_head'
,
1
,
True
),
(
'2_cls_custom_head'
,
2
,
True
))
def
test_bert_trainer_tensor_call
(
self
,
num_classes
,
use_custom_head
):
"""Validate that the Keras object can be invoked."""
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
# a short sequence_length for convenience.)
test_network
=
networks
.
BertEncoder
(
vocab_size
=
100
,
num_layers
=
2
)
test_network
=
networks
.
BertEncoder
(
vocab_size
=
100
,
num_layers
=
2
)
cls_head
=
layers
.
GaussianProcessClassificationHead
(
inner_dim
=
0
,
num_classes
=
num_classes
)
if
use_custom_head
else
None
# Create a BERT trainer with the created network.
# Create a BERT trainer with the created network.
bert_trainer_model
=
bert_classifier
.
BertClassifier
(
bert_trainer_model
=
bert_classifier
.
BertClassifier
(
test_network
,
num_classes
=
num_classes
)
test_network
,
num_classes
=
num_classes
,
cls_head
=
cls_head
)
# Create a set of 2-dimensional data tensors to feed into the model.
# Create a set of 2-dimensional data tensors to feed into the model.
word_ids
=
tf
.
constant
([[
1
,
1
],
[
2
,
2
]],
dtype
=
tf
.
int32
)
word_ids
=
tf
.
constant
([[
1
,
1
],
[
2
,
2
]],
dtype
=
tf
.
int32
)
...
@@ -74,7 +81,11 @@ class BertClassifierTest(keras_parameterized.TestCase):
...
@@ -74,7 +81,11 @@ class BertClassifierTest(keras_parameterized.TestCase):
# too complex: this simply ensures we're not hitting runtime errors.)
# too complex: this simply ensures we're not hitting runtime errors.)
_
=
bert_trainer_model
([
word_ids
,
mask
,
type_ids
])
_
=
bert_trainer_model
([
word_ids
,
mask
,
type_ids
])
def
test_serialize_deserialize
(
self
):
@
parameterized
.
named_parameters
(
(
'default_cls_head'
,
None
),
(
'sngp_cls_head'
,
layers
.
GaussianProcessClassificationHead
(
inner_dim
=
0
,
num_classes
=
4
)))
def
test_serialize_deserialize
(
self
,
cls_head
):
"""Validate that the BERT trainer can be serialized and deserialized."""
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
# a short sequence_length for convenience.)
...
@@ -84,7 +95,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
...
@@ -84,7 +95,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
# Create a BERT trainer with the created network. (Note that all the args
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
# are different, so we can catch any serialization mismatches.)
bert_trainer_model
=
bert_classifier
.
BertClassifier
(
bert_trainer_model
=
bert_classifier
.
BertClassifier
(
test_network
,
num_classes
=
4
,
initializer
=
'zeros'
)
test_network
,
num_classes
=
4
,
initializer
=
'zeros'
,
cls_head
=
cls_head
)
# Create another BERT trainer via serialization and deserialization.
# Create another BERT trainer via serialization and deserialization.
config
=
bert_trainer_model
.
get_config
()
config
=
bert_trainer_model
.
get_config
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment