Commit 23ef9155 authored by Jeremiah Liu, committed by A. Unique TensorFlower
Browse files

Deprecate `network.Classification` from `BertClassifier`.

PiperOrigin-RevId: 364722225
parent 0674ba0f
...@@ -36,7 +36,8 @@ class ClassificationHead(tf.keras.layers.Layer): ...@@ -36,7 +36,8 @@ class ClassificationHead(tf.keras.layers.Layer):
"""Initializes the `ClassificationHead`. """Initializes the `ClassificationHead`.
Args: Args:
inner_dim: The dimensionality of inner projection layer. inner_dim: The dimensionality of inner projection layer. If 0 or `None`
then only the output projection layer is created.
num_classes: Number of output classes. num_classes: Number of output classes.
cls_token_idx: The index inside the sequence to pool. cls_token_idx: The index inside the sequence to pool.
activation: Dense layer activation. activation: Dense layer activation.
...@@ -52,19 +53,25 @@ class ClassificationHead(tf.keras.layers.Layer): ...@@ -52,19 +53,25 @@ class ClassificationHead(tf.keras.layers.Layer):
self.initializer = tf.keras.initializers.get(initializer) self.initializer = tf.keras.initializers.get(initializer)
self.cls_token_idx = cls_token_idx self.cls_token_idx = cls_token_idx
self.dense = tf.keras.layers.Dense( if self.inner_dim:
units=inner_dim, self.dense = tf.keras.layers.Dense(
activation=self.activation, units=self.inner_dim,
kernel_initializer=self.initializer, activation=self.activation,
name="pooler_dense") kernel_initializer=self.initializer,
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) name="pooler_dense")
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.out_proj = tf.keras.layers.Dense( self.out_proj = tf.keras.layers.Dense(
units=num_classes, kernel_initializer=self.initializer, name="logits") units=num_classes, kernel_initializer=self.initializer, name="logits")
def call(self, features): def call(self, features):
x = features[:, self.cls_token_idx, :] # take <CLS> token. if not self.inner_dim:
x = self.dense(x) x = features
x = self.dropout(x) else:
x = features[:, self.cls_token_idx, :] # take <CLS> token.
x = self.dense(x)
x = self.dropout(x)
x = self.out_proj(x) x = self.out_proj(x)
return x return x
...@@ -103,7 +110,8 @@ class MultiClsHeads(tf.keras.layers.Layer): ...@@ -103,7 +110,8 @@ class MultiClsHeads(tf.keras.layers.Layer):
"""Initializes the `MultiClsHeads`. """Initializes the `MultiClsHeads`.
Args: Args:
inner_dim: The dimensionality of inner projection layer. inner_dim: The dimensionality of inner projection layer. If 0 or `None`
then only the output projection layer is created.
cls_list: a list of pairs of (classification problem name and the numbers cls_list: a list of pairs of (classification problem name and the numbers
of classes. of classes.
cls_token_idx: The index inside the sequence to pool. cls_token_idx: The index inside the sequence to pool.
...@@ -120,12 +128,13 @@ class MultiClsHeads(tf.keras.layers.Layer): ...@@ -120,12 +128,13 @@ class MultiClsHeads(tf.keras.layers.Layer):
self.initializer = tf.keras.initializers.get(initializer) self.initializer = tf.keras.initializers.get(initializer)
self.cls_token_idx = cls_token_idx self.cls_token_idx = cls_token_idx
self.dense = tf.keras.layers.Dense( if self.inner_dim:
units=inner_dim, self.dense = tf.keras.layers.Dense(
activation=self.activation, units=inner_dim,
kernel_initializer=self.initializer, activation=self.activation,
name="pooler_dense") kernel_initializer=self.initializer,
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) name="pooler_dense")
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.out_projs = [] self.out_projs = []
for name, num_classes in cls_list: for name, num_classes in cls_list:
self.out_projs.append( self.out_projs.append(
...@@ -134,9 +143,13 @@ class MultiClsHeads(tf.keras.layers.Layer): ...@@ -134,9 +143,13 @@ class MultiClsHeads(tf.keras.layers.Layer):
name=name)) name=name))
def call(self, features): def call(self, features):
x = features[:, self.cls_token_idx, :] # take <CLS> token. if not self.inner_dim:
x = self.dense(x) x = features
x = self.dropout(x) else:
x = features[:, self.cls_token_idx, :] # take <CLS> token.
x = self.dense(x)
x = self.dropout(x)
outputs = {} outputs = {}
for proj_layer in self.out_projs: for proj_layer in self.out_projs:
outputs[proj_layer.name] = proj_layer(x) outputs[proj_layer.name] = proj_layer(x)
...@@ -195,7 +208,8 @@ class GaussianProcessClassificationHead(ClassificationHead): ...@@ -195,7 +208,8 @@ class GaussianProcessClassificationHead(ClassificationHead):
"""Initializes the `GaussianProcessClassificationHead`. """Initializes the `GaussianProcessClassificationHead`.
Args: Args:
inner_dim: The dimensionality of inner projection layer. inner_dim: The dimensionality of inner projection layer. If 0 or `None`
then only the output projection layer is created.
num_classes: Number of output classes. num_classes: Number of output classes.
cls_token_idx: The index inside the sequence to pool. cls_token_idx: The index inside the sequence to pool.
activation: Dense layer activation. activation: Dense layer activation.
...@@ -220,8 +234,8 @@ class GaussianProcessClassificationHead(ClassificationHead): ...@@ -220,8 +234,8 @@ class GaussianProcessClassificationHead(ClassificationHead):
initializer=initializer, initializer=initializer,
**kwargs) **kwargs)
# Applies spectral normalization to the pooler layer. # Applies spectral normalization to the dense pooler layer.
if use_spec_norm: if self.use_spec_norm and hasattr(self, "dense"):
self.dense = spectral_normalization.SpectralNormalization( self.dense = spectral_normalization.SpectralNormalization(
self.dense, inhere_layer_name=True, **self.spec_norm_kwargs) self.dense, inhere_layer_name=True, **self.spec_norm_kwargs)
......
...@@ -13,13 +13,24 @@ ...@@ -13,13 +13,24 @@
# limitations under the License. # limitations under the License.
"""Tests for cls_head.""" """Tests for cls_head."""
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling.layers import cls_head from official.nlp.modeling.layers import cls_head
class ClassificationHeadTest(tf.test.TestCase): class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("no_pooler_layer", 0, 2),
("has_pooler_layer", 5, 4))
def test_pooler_layer(self, inner_dim, num_weights_expected):
test_layer = cls_head.ClassificationHead(inner_dim=inner_dim, num_classes=2)
features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
_ = test_layer(features)
num_weights_observed = len(test_layer.get_weights())
self.assertEqual(num_weights_observed, num_weights_expected)
def test_layer_invocation(self): def test_layer_invocation(self):
test_layer = cls_head.ClassificationHead(inner_dim=5, num_classes=2) test_layer = cls_head.ClassificationHead(inner_dim=5, num_classes=2)
...@@ -37,7 +48,18 @@ class ClassificationHeadTest(tf.test.TestCase): ...@@ -37,7 +48,18 @@ class ClassificationHeadTest(tf.test.TestCase):
self.assertAllEqual(layer.get_config(), new_layer.get_config()) self.assertAllEqual(layer.get_config(), new_layer.get_config())
class MultiClsHeadsTest(tf.test.TestCase): class MultiClsHeadsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("no_pooler_layer", 0, 4),
("has_pooler_layer", 5, 6))
def test_pooler_layer(self, inner_dim, num_weights_expected):
cls_list = [("foo", 2), ("bar", 3)]
test_layer = cls_head.MultiClsHeads(inner_dim=inner_dim, cls_list=cls_list)
features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
_ = test_layer(features)
num_weights_observed = len(test_layer.get_weights())
self.assertEqual(num_weights_observed, num_weights_expected)
def test_layer_invocation(self): def test_layer_invocation(self):
cls_list = [("foo", 2), ("bar", 3)] cls_list = [("foo", 2), ("bar", 3)]
...@@ -58,13 +80,31 @@ class MultiClsHeadsTest(tf.test.TestCase): ...@@ -58,13 +80,31 @@ class MultiClsHeadsTest(tf.test.TestCase):
self.assertAllEqual(test_layer.get_config(), new_layer.get_config()) self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
class GaussianProcessClassificationHead(tf.test.TestCase): class GaussianProcessClassificationHead(tf.test.TestCase,
parameterized.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.spec_norm_kwargs = dict(norm_multiplier=1.,) self.spec_norm_kwargs = dict(norm_multiplier=1.,)
self.gp_layer_kwargs = dict(num_inducing=512) self.gp_layer_kwargs = dict(num_inducing=512)
@parameterized.named_parameters(("no_pooler_layer", 0, 7),
("has_pooler_layer", 5, 11))
def test_pooler_layer(self, inner_dim, num_weights_expected):
test_layer = cls_head.GaussianProcessClassificationHead(
inner_dim=inner_dim,
num_classes=2,
use_spec_norm=True,
use_gp_layer=True,
initializer="zeros",
**self.spec_norm_kwargs,
**self.gp_layer_kwargs)
features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
_ = test_layer(features)
num_weights_observed = len(test_layer.get_weights())
self.assertEqual(num_weights_observed, num_weights_expected)
def test_layer_invocation(self): def test_layer_invocation(self):
test_layer = cls_head.GaussianProcessClassificationHead( test_layer = cls_head.GaussianProcessClassificationHead(
inner_dim=5, inner_dim=5,
......
...@@ -18,7 +18,6 @@ import collections ...@@ -18,7 +18,6 @@ import collections
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling import layers from official.nlp.modeling import layers
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
...@@ -46,6 +45,10 @@ class BertClassifier(tf.keras.Model): ...@@ -46,6 +45,10 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head. dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder. encoder.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
""" """
def __init__(self, def __init__(self,
...@@ -54,7 +57,12 @@ class BertClassifier(tf.keras.Model): ...@@ -54,7 +57,12 @@ class BertClassifier(tf.keras.Model):
initializer='glorot_uniform', initializer='glorot_uniform',
dropout_rate=0.1, dropout_rate=0.1,
use_encoder_pooler=True, use_encoder_pooler=True,
cls_head=None,
**kwargs): **kwargs):
self.num_classes = num_classes
self.initializer = initializer
self.use_encoder_pooler = use_encoder_pooler
self.cls_head = cls_head
# We want to use the inputs of the passed network as the inputs to this # We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use # Model. To do this, we need to keep a handle to the network inputs for use
...@@ -66,31 +74,28 @@ class BertClassifier(tf.keras.Model): ...@@ -66,31 +74,28 @@ class BertClassifier(tf.keras.Model):
# invoke the Network object with its own input tensors to start the Model. # invoke the Network object with its own input tensors to start the Model.
outputs = network(inputs) outputs = network(inputs)
if isinstance(outputs, list): if isinstance(outputs, list):
cls_output = outputs[1] cls_inputs = outputs[1]
else: else:
cls_output = outputs['pooled_output'] cls_inputs = outputs['pooled_output']
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output) cls_inputs = tf.keras.layers.Dropout(rate=dropout_rate)(cls_inputs)
classifier = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output='logits',
name='sentence_prediction')
predictions = classifier(cls_output)
else: else:
outputs = network(inputs) outputs = network(inputs)
if isinstance(outputs, list): if isinstance(outputs, list):
sequence_output = outputs[0] cls_inputs = outputs[0]
else: else:
sequence_output = outputs['sequence_output'] cls_inputs = outputs['sequence_output']
if cls_head:
classifier = cls_head
else:
classifier = layers.ClassificationHead( classifier = layers.ClassificationHead(
inner_dim=sequence_output.shape[-1], inner_dim=0 if use_encoder_pooler else cls_inputs.shape[-1],
num_classes=num_classes, num_classes=num_classes,
initializer=initializer, initializer=initializer,
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
name='sentence_prediction') name='sentence_prediction')
predictions = classifier(sequence_output)
predictions = classifier(cls_inputs)
# b/164516224 # b/164516224
# Once we've created the network using the Functional API, we call # Once we've created the network using the Functional API, we call
...@@ -102,13 +107,7 @@ class BertClassifier(tf.keras.Model): ...@@ -102,13 +107,7 @@ class BertClassifier(tf.keras.Model):
super(BertClassifier, self).__init__( super(BertClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs) inputs=inputs, outputs=predictions, **kwargs)
self._network = network self._network = network
config_dict = { config_dict = self._make_config_dict()
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'use_encoder_pooler': use_encoder_pooler,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint # We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track # compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which # the config dict attribute. TF does not track immutable attrs which
...@@ -132,3 +131,12 @@ class BertClassifier(tf.keras.Model): ...@@ -132,3 +131,12 @@ class BertClassifier(tf.keras.Model):
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
return cls(**config) return cls(**config)
def _make_config_dict(self):
return {
'network': self._network,
'num_classes': self.num_classes,
'initializer': self.initializer,
'use_encoder_pooler': self.use_encoder_pooler,
'cls_head': self.cls_head,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment