Commit d07c870f authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Adds a MultiClsHeads layer.

PiperOrigin-RevId: 344926166
parent 212299e2
......@@ -42,7 +42,7 @@ class ClassificationHead(tf.keras.layers.Layer):
initializer: Initializer for dense layer kernels.
**kwargs: Keyword arguments.
"""
super(ClassificationHead, self).__init__(**kwargs)
super().__init__(**kwargs)
self.dropout_rate = dropout_rate
self.inner_dim = inner_dim
self.num_classes = num_classes
......@@ -68,6 +68,7 @@ class ClassificationHead(tf.keras.layers.Layer):
def get_config(self):
config = {
"cls_token_idx": self.cls_token_idx,
"dropout_rate": self.dropout_rate,
"num_classes": self.num_classes,
"inner_dim": self.inner_dim,
......@@ -84,3 +85,78 @@ class ClassificationHead(tf.keras.layers.Layer):
@property
def checkpoint_items(self):
return {self.dense.name: self.dense}
class MultiClsHeads(tf.keras.layers.Layer):
  """Multiple classification heads that share a single pooling stem.

  One dense "pooler" projects the <CLS> token representation; each
  classification problem listed in `cls_list` then applies its own output
  projection on top of that shared pooled vector.
  """

  def __init__(self,
               inner_dim,
               cls_list,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `MultiClsHeads`.

    Args:
      inner_dim: The dimensionality of the inner (pooler) projection layer.
      cls_list: A list of (classification problem name, number of classes)
        pairs; one output projection layer is created per entry.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability applied to the pooled output.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments forwarded to the base `Layer`.
    """
    super().__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.cls_list = cls_list
    self.cls_token_idx = cls_token_idx
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf.keras.initializers.get(initializer)
    # Shared pooling stem applied once for all heads.
    self.dense = tf.keras.layers.Dense(
        units=inner_dim,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name="pooler_dense")
    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    # One output projection per classification problem, named after it so
    # `call` can key its outputs by the problem name.
    self.out_projs = [
        tf.keras.layers.Dense(
            units=n_classes, kernel_initializer=self.initializer,
            name=head_name)
        for head_name, n_classes in cls_list
    ]

  def call(self, features):
    """Pools the <CLS> token and runs every classification head.

    Args:
      features: Sequence features; the token at `cls_token_idx` is pooled.

    Returns:
      A dict mapping each head's name to its logits.
    """
    pooled = features[:, self.cls_token_idx, :]  # take <CLS> token.
    pooled = self.dropout(self.dense(pooled))
    return {head.name: head(pooled) for head in self.out_projs}

  def get_config(self):
    config = dict(
        dropout_rate=self.dropout_rate,
        cls_token_idx=self.cls_token_idx,
        cls_list=self.cls_list,
        inner_dim=self.inner_dim,
        activation=tf.keras.activations.serialize(self.activation),
        initializer=tf.keras.initializers.serialize(self.initializer),
    )
    # Base-layer config (name, dtype, ...) takes precedence on collision.
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    # TODO(hongkuny): add output projects to the checkpoint items.
    return {self.dense.name: self.dense}
......@@ -20,7 +20,7 @@ import tensorflow as tf
from official.nlp.modeling.layers import cls_head
class ClassificationHead(tf.test.TestCase):
class ClassificationHeadTest(tf.test.TestCase):
def test_layer_invocation(self):
test_layer = cls_head.ClassificationHead(inner_dim=5, num_classes=2)
......@@ -38,5 +38,26 @@ class ClassificationHead(tf.test.TestCase):
self.assertAllEqual(layer.get_config(), new_layer.get_config())
class MultiClsHeadsTest(tf.test.TestCase):
  """Tests for the MultiClsHeads layer."""

  def test_layer_invocation(self):
    # Two heads over a shared pooler; zero inputs yield zero logits because
    # the dense layers have zero biases and see an all-zero pooled vector.
    heads = [("foo", 2), ("bar", 3)]
    test_layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=heads)
    features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
    outputs = test_layer(features)
    self.assertAllClose(outputs["foo"], [[0., 0.], [0., 0.]])
    self.assertAllClose(outputs["bar"], [[0., 0., 0.], [0., 0., 0.]])
    # Only the shared pooler is tracked as a checkpoint item.
    self.assertSameElements(test_layer.checkpoint_items.keys(),
                            ["pooler_dense"])

  def test_layer_serialization(self):
    heads = [("foo", 2), ("bar", 3)]
    test_layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=heads)
    new_layer = cls_head.MultiClsHeads.from_config(test_layer.get_config())
    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
# Run all test cases in this module under the TensorFlow test runner.
if __name__ == "__main__":
  tf.test.main()
......@@ -183,7 +183,7 @@ class BertPretrainerV2(tf.keras.Model):
dictionary.
Outputs: A dictionary of `lm_output`, classification head outputs keyed by
head names, and also outputs from `encoder_network`, keyed by
`pooled_output`, `sequence_output` and `encoder_outputs` (if any).
`sequence_output` and `encoder_outputs` (if any).
"""
def __init__(
......@@ -248,7 +248,11 @@ class BertPretrainerV2(tf.keras.Model):
outputs['mlm_logits'] = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions)
for cls_head in self.classification_heads:
outputs[cls_head.name] = cls_head(sequence_output)
cls_outputs = cls_head(sequence_output)
if isinstance(cls_outputs, dict):
outputs.update(cls_outputs)
else:
outputs[cls_head.name] = cls_outputs
return outputs
@property
......
......@@ -110,6 +110,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
self.assertAllEqual(bert_trainer_model.get_config(),
new_bert_trainer_model.get_config())
class BertPretrainerV2Test(keras_parameterized.TestCase):
@parameterized.parameters(itertools.product(
(False, True),
(False, True),
......@@ -175,6 +178,38 @@ class BertPretrainerTest(keras_parameterized.TestCase):
self.assertAllEqual(expected_pooled_output_shape,
outputs['pooled_output'].shape.as_list())
def test_multiple_cls_outputs(self):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
hidden_size = 48
num_layers = 2
test_network = networks.BertEncoder(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
max_sequence_length=sequence_length,
dict_outputs=True)
bert_trainer_model = bert_pretrainer.BertPretrainerV2(
encoder_network=test_network,
classification_heads=[layers.MultiClsHeads(
inner_dim=5, cls_list=[('foo', 2), ('bar', 3)])])
num_token_predictions = 20
# Create a set of 2-dimensional inputs (the first dimension is implicit).
inputs = dict(
input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
masked_lm_positions=tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32))
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model(inputs)
self.assertEqual(outputs['foo'].shape.as_list(), [None, 2])
self.assertEqual(outputs['bar'].shape.as_list(), [None, 3])
def test_v2_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment