Commit 2cf59d12 authored by Hongkun Yu, committed by A. Unique TensorFlower

adds a ClassificationHead layer

PiperOrigin-RevId: 313906815
parent b45dd807
# Layers
Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models.
* [DenseEinsum](dense_einsum.py) implements a feedforward network using
tf.einsum. This layer contains the einsum op, the associated weight, and the
logic required to generate the einsum expression for the given
initialization parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
between two tensors, from_tensor and to_tensor, as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
`from_tensor` and `to_tensor` are the same, then this is self-attention.
* [CachedAttention](attention.py) implements an attention layer with cache
used for auto-regressive decoding.
* [TalkingHeadsAttention](talking_heads_attention.py) implements the talking
heads attention, as described in
["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).
* [Transformer](transformer.py) implements an optionally masked transformer as
described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with
ReZero described in
["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
* [OnDeviceEmbedding](on_device_embedding.py) implements efficient embedding
lookups designed for TPU-based models.
* [PositionalEmbedding](position_embedding.py) creates a positional embedding
as described in ["BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding"](https://arxiv.org/abs/1810.04805).
* [SelfAttentionMask](self_attention_mask.py) creates a 3D attention mask from
a 2D tensor mask.
* [MaskedSoftmax](masked_softmax.py) implements a softmax with an optional
masking input. If no mask is provided to this layer, it performs a standard
softmax; however, if a mask tensor is applied (which should be 1 in
positions where the data should be allowed through, and 0 where the data
should be masked), the output will have masked positions set to
approximately zero.
* [ClassificationHead](cls_head.py) implements a pooling head over a sequence
of embeddings, commonly used for classification tasks (see the usage sketch
after this list).
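
As a rough usage sketch of the new `ClassificationHead` (the encoder output below is a random stand-in; a real model would pass its sequence embeddings):

```python
import tensorflow as tf
from official.nlp.modeling.layers import cls_head

# Stand-in for an encoder's sequence output: [batch, seq_length, hidden].
encoder_output = tf.random.uniform(shape=(2, 10, 16))

# Pools the embedding at `cls_token_idx` (default 0), projects it through a
# tanh dense layer of width `inner_dim`, applies dropout, and emits logits.
head = cls_head.ClassificationHead(inner_dim=8, num_classes=3)
logits = head(encoder_output)  # shape: (2, 3)
```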
@@ -13,7 +13,9 @@
# limitations under the License.
# ==============================================================================
"""Layers package definition."""
# pylint: disable=wildcard-import
from official.nlp.modeling.layers.attention import *
from official.nlp.modeling.layers.cls_head import *
from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
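
Because `cls_head` is pulled in by the star imports above, the new layer should also be reachable from the package namespace; a small sketch (assuming the usual `official.nlp.modeling` package layout):

```python
from official.nlp.modeling import layers

# ClassificationHead is re-exported by the package __init__ shown above.
head = layers.ClassificationHead(inner_dim=64, num_classes=2)
```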
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Classification head layer which is common used with sequence encoders."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils


class ClassificationHead(tf.keras.layers.Layer):
  """Pooling head for sentence-level classification tasks."""

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `ClassificationHead`.

    Args:
      inner_dim: The dimensionality of the inner projection layer.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
    super(ClassificationHead, self).__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.num_classes = num_classes
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf.keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    self.dense = tf.keras.layers.Dense(
        units=inner_dim,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name="pooler_dense")
    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    self.out_proj = tf.keras.layers.Dense(
        units=num_classes, kernel_initializer=self.initializer, name="logits")

  def call(self, features):
    x = features[:, self.cls_token_idx, :]  # take <CLS> token.
    x = self.dense(x)
    x = self.dropout(x)
    x = self.out_proj(x)
    return x

  def get_config(self):
    config = {
        "dropout_rate": self.dropout_rate,
        "num_classes": self.num_classes,
        "inner_dim": self.inner_dim,
        "activation": tf.keras.activations.serialize(self.activation),
        "initializer": tf.keras.initializers.serialize(self.initializer),
    }
    config.update(super(ClassificationHead, self).get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    return {self.dense.name: self.dense}
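
The `checkpoint_items` property above exposes the pooler dense layer under its layer name ("pooler_dense") so it can be tracked separately from the logits projection. A minimal sketch of wiring it into a `tf.train.Checkpoint` (the save path is illustrative):

```python
import tensorflow as tf
from official.nlp.modeling.layers import cls_head

head = cls_head.ClassificationHead(inner_dim=64, num_classes=2)
head(tf.zeros((1, 8, 16)))  # Call once so the sublayer variables are created.

# checkpoint_items maps "pooler_dense" to the inner dense layer; passing it as
# keyword trackables lets those weights be saved and restored by name.
checkpoint = tf.train.Checkpoint(**head.checkpoint_items)
checkpoint.save("/tmp/cls_head_pooler")  # Illustrative path.
```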
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cls_head."""
import tensorflow as tf
from official.nlp.modeling.layers import cls_head


class ClassificationHeadTest(tf.test.TestCase):

  def test_layer_invocation(self):
    test_layer = cls_head.ClassificationHead(inner_dim=5, num_classes=2)
    features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
    output = test_layer(features)
    self.assertAllClose(output, [[0., 0.], [0., 0.]])
    self.assertSameElements(test_layer.checkpoint_items.keys(),
                            ["pooler_dense"])

  def test_layer_serialization(self):
    layer = cls_head.ClassificationHead(10, 2)
    new_layer = cls_head.ClassificationHead.from_config(layer.get_config())

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(layer.get_config(), new_layer.get_config())


if __name__ == "__main__":
  tf.test.main()