Commit 6d28faff authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 392968271
parent 33598c45
......@@ -64,14 +64,28 @@ class ClassificationHead(tf.keras.layers.Layer):
self.out_proj = tf.keras.layers.Dense(
units=num_classes, kernel_initializer=self.initializer, name="logits")
def call(self, features):
def call(self, features: tf.Tensor, only_project: bool = False):
"""Implements call().
Args:
features: a rank-3 Tensor when self.inner_dim is specified, otherwise
it is a rank-2 Tensor.
only_project: a boolean. If True, we return the intermediate Tensor
before projecting to class logits.
Returns:
a Tensor, if only_project is True, shape= [batch size, hidden size].
If only_project is False, shape= [batch size, num classes].
"""
if not self.inner_dim:
x = features
else:
x = features[:, self.cls_token_idx, :] # take <CLS> token.
x = self.dense(x)
x = self.dropout(x)
if only_project:
return x
x = self.dropout(x)
x = self.out_proj(x)
return x
......@@ -142,12 +156,27 @@ class MultiClsHeads(tf.keras.layers.Layer):
units=num_classes, kernel_initializer=self.initializer,
name=name))
def call(self, features):
def call(self, features: tf.Tensor, only_project: bool = False):
"""Implements call().
Args:
features: a rank-3 Tensor when self.inner_dim is specified, otherwise
it is a rank-2 Tensor.
only_project: a boolean. If True, we return the intermediate Tensor
before projecting to class logits.
Returns:
If only_project is True, a Tensor with shape= [batch size, hidden size].
If only_project is False, a dictionary of Tensors.
"""
if not self.inner_dim:
x = features
else:
x = features[:, self.cls_token_idx, :] # take <CLS> token.
x = self.dense(x)
if only_project:
return x
x = self.dropout(x)
outputs = {}
......
......@@ -39,6 +39,8 @@ class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(output, [[0., 0.], [0., 0.]])
self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense"])
outputs = test_layer(features, only_project=True)
self.assertEqual(outputs.shape, (2, 5))
def test_layer_serialization(self):
layer = cls_head.ClassificationHead(10, 2)
......@@ -71,6 +73,9 @@ class MultiClsHeadsTest(tf.test.TestCase, parameterized.TestCase):
self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense", "foo", "bar"])
outputs = test_layer(features, only_project=True)
self.assertEqual(outputs.shape, (2, 5))
def test_layer_serialization(self):
cls_list = [("foo", 2), ("bar", 3)]
test_layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=cls_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment