Commit 8c7699ce authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 446536595
parent 71277b49
...@@ -364,3 +364,97 @@ def extract_spec_norm_kwargs(kwargs): ...@@ -364,3 +364,97 @@ def extract_spec_norm_kwargs(kwargs):
  return dict(
      iteration=kwargs.pop("iteration", 1),
      norm_multiplier=kwargs.pop("norm_multiplier", .99))
class PerQueryDenseHead(tf.keras.layers.Layer):
  """Pooling head used for EncT5 style models.

  This module projects each query with a different (per-query) projection.
  For an input of shape [bs, num_queries, hidden_size], each query is
  projected to `features` outputs, yielding shape [bs, num_queries, features].

  For example, for classification with a few classes, one may use num_queries
  as 1 and features as number of classes. For multilabel classification, one
  may use num_queries as number of classes and features as 2, so each query
  represents a binary classification of one label.
  """

  def __init__(self,
               num_queries: int,
               features: int,
               use_bias: bool = False,
               kernel_initializer: str = "glorot_uniform",
               **kwargs):
    """Initializes the `PerQueryDenseHead`.

    Args:
      num_queries: number of queries (the learnable embeddings in the input
        sequences) from the decoder.
      features: int with number of output features. Each query is projected
        to this number with a different projection.
      use_bias: whether to add a bias to the output.
      kernel_initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
    super().__init__(**kwargs)
    self.num_queries = num_queries
    self.features = features
    self.use_bias = use_bias
    self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)

  def build(self, input_shape):
    """Creates the per-query kernel (and optional bias) weights.

    Args:
      input_shape: shape of the layer input; the last dimension is the
        hidden size used as the projection's input width.
    """
    input_shape = tf.TensorShape(input_shape)
    # Hidden size of the incoming sequence representation.
    last_dim = tf.compat.dimension_value(input_shape[-1])
    self.hidden_size = last_dim
    # One independent [hidden_size, features] projection per query.
    self.kernel = self.add_weight(
        "kernel",
        shape=[self.num_queries, last_dim, self.features],
        initializer=self.kernel_initializer,
        dtype=self.dtype,
        trainable=True)
    if self.use_bias:
      # Zero-initialize the bias, matching the tf.keras.layers.Dense
      # convention (add_weight would otherwise default to glorot_uniform).
      self.bias = self.add_weight(
          "bias",
          shape=[
              self.num_queries,
              self.features,
          ],
          initializer="zeros",
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None

  def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Implements call().

    Args:
      inputs: a rank-3 Tensor of shape [bs, num_queries, hidden_size].

    Returns:
      A Tensor of shape [batch size, num_queries, features].
    """
    # Batched per-query matmul: each query q uses its own kernel slice.
    outputs = tf.einsum("bqh,qhf->bqf", inputs, self.kernel)
    if self.use_bias:
      outputs += self.bias
    return outputs

  def get_config(self):
    """Returns the config needed to reconstruct this layer.

    Includes `use_bias` so that serialization round-trips preserve the
    bias term, and uses the initializer (not activation) serializer for
    `kernel_initializer`.
    """
    config = {
        "num_queries":
            self.num_queries,
        "features":
            self.features,
        "use_bias":
            self.use_bias,
        "kernel_initializer":
            tf.keras.initializers.serialize(self.kernel_initializer),
    }
    config.update(super(PerQueryDenseHead, self).get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a layer from its config dict (inverse of `get_config`)."""
    return cls(**config)
...@@ -199,5 +199,29 @@ class GaussianProcessClassificationHead(tf.test.TestCase, ...@@ -199,5 +199,29 @@ class GaussianProcessClassificationHead(tf.test.TestCase,
    self.assertEqual(layer_config["norm_multiplier"], 1.)
    self.assertEqual(layer_config["num_inducing"], 512)
class PerQueryDenseHeadTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(("single_query", 1, 3, False),
                                  ("multi_queries", 10, 2, False),
                                  ("with_bias", 10, 2, True))
  def test_layer_invocation(self, num_queries, features, use_bias):
    """Checks the output shape for several head configurations."""
    bs, width = 5, 10
    head = cls_head.PerQueryDenseHead(
        num_queries=num_queries, features=features, use_bias=use_bias)
    dummy_inputs = tf.zeros((bs, num_queries, width), dtype=tf.float32)
    result = head(dummy_inputs)
    self.assertEqual(result.shape, [bs, num_queries, features])

  def test_layer_serialization(self):
    """Round-trips a layer through get_config/from_config."""
    original = cls_head.PerQueryDenseHead(
        num_queries=10, features=2, use_bias=True)
    restored = cls_head.PerQueryDenseHead.from_config(original.get_config())
    # A successful round trip yields an identical config.
    self.assertAllEqual(original.get_config(), restored.get_config())
if __name__ == "__main__":
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment