Commit 026a7880 authored by Jeremiah Liu's avatar Jeremiah Liu Committed by A. Unique TensorFlower
Browse files

Adds unit tests for using `cls_head` in `BertClassifier`.

PiperOrigin-RevId: 364952108
parent 08273bc2
......@@ -45,8 +45,8 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder.
cls_head: (Optional) The layer instance to use for the classifier head
. It should take in the output from network and produce the final logits.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
"""
......@@ -62,7 +62,6 @@ class BertClassifier(tf.keras.Model):
self.num_classes = num_classes
self.initializer = initializer
self.use_encoder_pooler = use_encoder_pooler
self.cls_head = cls_head
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
......@@ -107,6 +106,8 @@ class BertClassifier(tf.keras.Model):
super(BertClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs)
self._network = network
self._cls_head = cls_head
config_dict = self._make_config_dict()
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
......@@ -138,5 +139,5 @@ class BertClassifier(tf.keras.Model):
'num_classes': self.num_classes,
'initializer': self.initializer,
'use_encoder_pooler': self.use_encoder_pooler,
'cls_head': self.cls_head,
'cls_head': self._cls_head,
}
......@@ -18,6 +18,7 @@ from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier
......@@ -53,16 +54,22 @@ class BertClassifierTest(keras_parameterized.TestCase):
expected_classification_shape = [None, num_classes]
self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
@parameterized.parameters(1, 2)
def test_bert_trainer_tensor_call(self, num_classes):
@parameterized.named_parameters(
('single_cls', 1, False),
('2_cls', 2, False),
('single_cls_custom_head', 1, True),
('2_cls_custom_head', 2, True))
def test_bert_trainer_tensor_call(self, num_classes, use_custom_head):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
cls_head = layers.GaussianProcessClassificationHead(
inner_dim=0, num_classes=num_classes) if use_custom_head else None
# Create a BERT trainer with the created network.
bert_trainer_model = bert_classifier.BertClassifier(
test_network, num_classes=num_classes)
test_network, num_classes=num_classes, cls_head=cls_head)
# Create a set of 2-dimensional data tensors to feed into the model.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
......@@ -74,7 +81,11 @@ class BertClassifierTest(keras_parameterized.TestCase):
# too complex: this simply ensures we're not hitting runtime errors.)
_ = bert_trainer_model([word_ids, mask, type_ids])
def test_serialize_deserialize(self):
@parameterized.named_parameters(
('default_cls_head', None),
('sngp_cls_head', layers.GaussianProcessClassificationHead(
inner_dim=0, num_classes=4)))
def test_serialize_deserialize(self, cls_head):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
......@@ -84,7 +95,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
bert_trainer_model = bert_classifier.BertClassifier(
test_network, num_classes=4, initializer='zeros')
test_network, num_classes=4, initializer='zeros', cls_head=cls_head)
# Create another BERT trainer via serialization and deserialization.
config = bert_trainer_model.get_config()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment