Commit 2cf59d12 authored by Hongkun Yu, committed by A. Unique TensorFlower

adds a ClassificationHead layer

PiperOrigin-RevId: 313906815
parent b45dd807
official/nlp/modeling/layers/README.md

# Layers

Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models.

* [DenseEinsum](dense_einsum.py) implements a feedforward network using
  tf.einsum. This layer contains the einsum op, the associated weight, and the
  logic required to generate the einsum expression for the given
  initialization parameters.

* [MultiHeadAttention](attention.py) implements an optionally masked attention
  between two tensors, `from_tensor` and `to_tensor`, as described in
  ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
  `from_tensor` and `to_tensor` are the same, then this is self-attention.

* [CachedAttention](attention.py) implements an attention layer with a cache
  used for auto-regressive decoding.

* [TalkingHeadsAttention](talking_heads_attention.py) implements talking-heads
  attention, as described in
  ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).

* [Transformer](transformer.py) implements an optionally masked transformer as
  described in
  ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).

* [ReZeroTransformer](rezero_transformer.py) implements a Transformer with
  ReZero, as described in
  ["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).

* [OnDeviceEmbedding](on_device_embedding.py) implements efficient embedding
  lookups designed for TPU-based models.

* [PositionalEmbedding](position_embedding.py) creates a positional embedding
  as described in ["BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding"](https://arxiv.org/abs/1810.04805).

* [SelfAttentionMask](self_attention_mask.py) creates a 3D attention mask from
  a 2D tensor mask.

* [MaskedSoftmax](masked_softmax.py) implements a softmax with an optional
  masking input. If no mask is provided to this layer, it performs a standard
  softmax; however, if a mask tensor is applied (which should be 1 in
  positions where the data should be allowed through, and 0 where the data
  should be masked), the output will have masked positions set to
  approximately zero.

* [ClassificationHead](cls_head.py) implements a pooling head over a sequence
  of embeddings, commonly used for sentence-level classification tasks. A
  short usage sketch follows this list.
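A minimal usage sketch of the new `ClassificationHead`; the shapes,
hyperparameters, and random input below are illustrative and not part of this
change:

```python
import tensorflow as tf

from official.nlp.modeling import layers

# A batch of 2 sequences, each 16 tokens long with hidden size 32.
embeddings = tf.random.uniform(shape=(2, 16, 32))

# Pool the embedding at the CLS position (index 0 by default), project it to
# `inner_dim`, apply dropout, and produce `num_classes` logits.
head = layers.ClassificationHead(inner_dim=32, num_classes=3, dropout_rate=0.1)
logits = head(embeddings)  # Shape: (2, 3).
```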
official/nlp/modeling/layers/__init__.py

@@ -13,7 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 """Layers package definition."""
-from official.nlp.modeling.layers.attention import *   # pylint: disable=wildcard-import
+# pylint: disable=wildcard-import
+from official.nlp.modeling.layers.attention import *
+from official.nlp.modeling.layers.cls_head import *
 from official.nlp.modeling.layers.dense_einsum import DenseEinsum
 from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
...
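With the wildcard import added above, the new head can be reached either
through the layers package or directly from its module; a short sketch of the
two equivalent import paths (the constructor arguments are example values):

    # Via the package-level export added in this change.
    from official.nlp.modeling import layers
    head = layers.ClassificationHead(inner_dim=768, num_classes=2)

    # Or directly from the module that defines the layer.
    from official.nlp.modeling.layers import cls_head
    head = cls_head.ClassificationHead(inner_dim=768, num_classes=2)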
official/nlp/modeling/layers/cls_head.py (new file)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Classification head layer which is common used with sequence encoders."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
class ClassificationHead(tf.keras.layers.Layer):
  """Pooling head for sentence-level classification tasks."""

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `ClassificationHead`.

    Args:
      inner_dim: The dimensionality of the inner projection layer.
      num_classes: Number of output classes.
      cls_token_idx: The index of the token in the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
    super(ClassificationHead, self).__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.num_classes = num_classes
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf.keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    self.dense = tf.keras.layers.Dense(
        units=inner_dim,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name="pooler_dense")
    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    self.out_proj = tf.keras.layers.Dense(
        units=num_classes, kernel_initializer=self.initializer, name="logits")

  def call(self, features):
    # `features` has shape (batch_size, seq_length, hidden_size); the returned
    # logits have shape (batch_size, num_classes).
    x = features[:, self.cls_token_idx, :]  # Take the <CLS> token embedding.
    x = self.dense(x)
    x = self.dropout(x)
    x = self.out_proj(x)
    return x

  def get_config(self):
    config = {
        "dropout_rate": self.dropout_rate,
        "num_classes": self.num_classes,
        "inner_dim": self.inner_dim,
        "activation": tf.keras.activations.serialize(self.activation),
        "initializer": tf.keras.initializers.serialize(self.initializer),
    }
    config.update(super(ClassificationHead, self).get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    # Sub-layers exported for checkpointing under stable names.
    return {self.dense.name: self.dense}
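The `checkpoint_items` property above exports the pooler dense layer under a
stable name for checkpointing. A small sketch of how a caller might consume it;
the `tf.train.Checkpoint` wiring and the save path are illustrative
assumptions, not part of this change:

    import tensorflow as tf

    from official.nlp.modeling.layers import cls_head

    head = cls_head.ClassificationHead(inner_dim=64, num_classes=3)
    _ = head(tf.zeros((1, 8, 64)))  # Call once so the variables are created.

    # Track the pooler dense layer under its reported name ("pooler_dense").
    checkpoint = tf.train.Checkpoint(**head.checkpoint_items)
    checkpoint.save("/tmp/cls_head_ckpt")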
official/nlp/modeling/layers/cls_head_test.py (new file)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cls_head."""
import tensorflow as tf
from official.nlp.modeling.layers import cls_head
class ClassificationHead(tf.test.TestCase):
def test_layer_invocation(self):
test_layer = cls_head.ClassificationHead(inner_dim=5, num_classes=2)
features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
output = test_layer(features)
self.assertAllClose(output, [[0., 0.], [0., 0.]])
self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense"])
def test_layer_serialization(self):
layer = cls_head.ClassificationHead(10, 2)
new_layer = cls_head.ClassificationHead.from_config(layer.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(layer.get_config(), new_layer.get_config())
if __name__ == "__main__":
tf.test.main()