"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "9c7af8dfefb72d98f166c7477d483829358a06ff"
Commit 6d28faff authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 392968271
parent 33598c45
...@@ -59,19 +59,33 @@ class ClassificationHead(tf.keras.layers.Layer): ...@@ -59,19 +59,33 @@ class ClassificationHead(tf.keras.layers.Layer):
activation=self.activation, activation=self.activation,
kernel_initializer=self.initializer, kernel_initializer=self.initializer,
name="pooler_dense") name="pooler_dense")
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.out_proj = tf.keras.layers.Dense( self.out_proj = tf.keras.layers.Dense(
units=num_classes, kernel_initializer=self.initializer, name="logits") units=num_classes, kernel_initializer=self.initializer, name="logits")
def call(self, features): def call(self, features: tf.Tensor, only_project: bool = False):
"""Implements call().
Args:
features: a rank-3 Tensor when self.inner_dim is specified, otherwise
it is a rank-2 Tensor.
only_project: a boolean. If True, we return the intermediate Tensor
before projecting to class logits.
Returns:
a Tensor, if only_project is True, shape= [batch size, hidden size].
If only_project is False, shape= [batch size, num classes].
"""
if not self.inner_dim: if not self.inner_dim:
x = features x = features
else: else:
x = features[:, self.cls_token_idx, :] # take <CLS> token. x = features[:, self.cls_token_idx, :] # take <CLS> token.
x = self.dense(x) x = self.dense(x)
x = self.dropout(x)
if only_project:
return x
x = self.dropout(x)
x = self.out_proj(x) x = self.out_proj(x)
return x return x
...@@ -134,7 +148,7 @@ class MultiClsHeads(tf.keras.layers.Layer): ...@@ -134,7 +148,7 @@ class MultiClsHeads(tf.keras.layers.Layer):
activation=self.activation, activation=self.activation,
kernel_initializer=self.initializer, kernel_initializer=self.initializer,
name="pooler_dense") name="pooler_dense")
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.out_projs = [] self.out_projs = []
for name, num_classes in cls_list: for name, num_classes in cls_list:
self.out_projs.append( self.out_projs.append(
...@@ -142,13 +156,28 @@ class MultiClsHeads(tf.keras.layers.Layer): ...@@ -142,13 +156,28 @@ class MultiClsHeads(tf.keras.layers.Layer):
units=num_classes, kernel_initializer=self.initializer, units=num_classes, kernel_initializer=self.initializer,
name=name)) name=name))
def call(self, features): def call(self, features: tf.Tensor, only_project: bool = False):
"""Implements call().
Args:
features: a rank-3 Tensor when self.inner_dim is specified, otherwise
it is a rank-2 Tensor.
only_project: a boolean. If True, we return the intermediate Tensor
before projecting to class logits.
Returns:
If only_project is True, a Tensor with shape= [batch size, hidden size].
If only_project is False, a dictionary of Tensors.
"""
if not self.inner_dim: if not self.inner_dim:
x = features x = features
else: else:
x = features[:, self.cls_token_idx, :] # take <CLS> token. x = features[:, self.cls_token_idx, :] # take <CLS> token.
x = self.dense(x) x = self.dense(x)
x = self.dropout(x)
if only_project:
return x
x = self.dropout(x)
outputs = {} outputs = {}
for proj_layer in self.out_projs: for proj_layer in self.out_projs:
......
...@@ -39,6 +39,8 @@ class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase): ...@@ -39,6 +39,8 @@ class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(output, [[0., 0.], [0., 0.]]) self.assertAllClose(output, [[0., 0.], [0., 0.]])
self.assertSameElements(test_layer.checkpoint_items.keys(), self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense"]) ["pooler_dense"])
outputs = test_layer(features, only_project=True)
self.assertEqual(outputs.shape, (2, 5))
def test_layer_serialization(self): def test_layer_serialization(self):
layer = cls_head.ClassificationHead(10, 2) layer = cls_head.ClassificationHead(10, 2)
...@@ -71,6 +73,9 @@ class MultiClsHeadsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -71,6 +73,9 @@ class MultiClsHeadsTest(tf.test.TestCase, parameterized.TestCase):
self.assertSameElements(test_layer.checkpoint_items.keys(), self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense", "foo", "bar"]) ["pooler_dense", "foo", "bar"])
outputs = test_layer(features, only_project=True)
self.assertEqual(outputs.shape, (2, 5))
def test_layer_serialization(self): def test_layer_serialization(self):
cls_list = [("foo", 2), ("bar", 3)] cls_list = [("foo", 2), ("bar", 3)]
test_layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=cls_list) test_layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=cls_list)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment