Commit c57e975a authored by saberkun's avatar saberkun
Browse files

Merge pull request #10338 from srihari-humbarwadi:readme

PiperOrigin-RevId: 413033276
parents 7fb4f3cd acf4156e
......@@ -113,6 +113,7 @@ python3 train.py \
--experiment=bert/sentence_prediction \
--mode=train_and_eval \
--model_dir=$OUTPUT_DIR \
--config_file=configs/models/bert_en_uncased_base.yaml \
--config_file=configs/experiments/glue_mnli_matched.yaml \
--tfhub_cache_dir=$OUTPUT_DIR/hub_cache \
--tpu=${TPU_NAME} \
......@@ -172,6 +173,7 @@ python3 train.py \
--experiment=bert/squad \
--mode=train_and_eval \
--model_dir=$OUTPUT_DIR \
--config_file=configs/models/bert_en_uncased_base.yaml \
--config_file=configs/experiments/squad_v1.1.yaml \
--tpu=${TPU_NAME} \
--params_override=$PARAMS
......
......@@ -50,6 +50,14 @@ assemble new `tf.keras` layers or models.
feature-based Gaussian process described in ["Random Features for
Large-Scale Kernel Machines"](https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf).
* [ReuseMultiHeadAttention](reuse_attention.py) supports passing
attention scores to be reused and avoid recomputation described in
["Leveraging redundancy in attention with Reuse Transformers"](https://arxiv.org/abs/2110.06821).
* [ReuseTransformer](reuse_transformer.py) supports reusing attention scores
from lower layers in higher layers to avoid recomputing attention scores
described in ["Leveraging redundancy in attention with Reuse Transformers"](https://arxiv.org/abs/2110.06821).
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with
ReZero described in
["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
......
......@@ -21,7 +21,6 @@ from official.nlp.modeling.layers.attention import *
from official.nlp.modeling.layers.bigbird_attention import BigBirdAttention
from official.nlp.modeling.layers.bigbird_attention import BigBirdMasks
from official.nlp.modeling.layers.cls_head import *
from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.gaussian_process import RandomFeatureGaussianProcess
from official.nlp.modeling.layers.kernel_attention import KernelAttention
......@@ -39,6 +38,8 @@ from official.nlp.modeling.layers.position_embedding import RelativePositionBias
from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
from official.nlp.modeling.layers.relative_attention import MultiHeadRelativeAttention
from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAttention
from official.nlp.modeling.layers.reuse_attention import ReuseMultiHeadAttention
from official.nlp.modeling.layers.reuse_transformer import ReuseTransformer
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.spectral_normalization import *
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based einsum layer."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
from tensorflow.python.util import deprecation
_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
@tf.keras.utils.register_keras_serializable(package="Text")
class DenseEinsum(tf.keras.layers.Layer):
"""A densely connected layer that uses `tf.einsum` as the backing computation.
This layer can perform einsum calculations of arbitrary dimensionality.
Args:
output_shape: Positive integer or tuple, dimensionality of the output space.
num_summed_dimensions: The number of dimensions to sum over. Standard 2D
matmul should use 1, 3D matmul should use 2, and so forth.
activation: Activation function to use. If you don't specify anything, no
activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix.
bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation")..
kernel_constraint: Constraint function applied to the `kernel` weights
matrix.
bias_constraint: Constraint function applied to the bias vector.
Input shape:
N-D tensor with shape: `(batch_size, ..., input_dim)`. The most common
situation would be a 2D input with shape `(batch_size, input_dim)`.
Output shape:
N-D tensor with shape: `(batch_size, ..., units)`. For instance, for a 2D
input with shape `(batch_size, input_dim)`, the output would have shape
`(batch_size, units)`.
"""
@deprecation.deprecated(None, "DenseEinsum is deprecated. Please use "
"tf.keras.experimental.EinsumDense layer instead.")
def __init__(self,
output_shape,
num_summed_dimensions=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(DenseEinsum, self).__init__(**kwargs)
self._output_shape = output_shape if isinstance(
output_shape, (list, tuple)) else (output_shape,)
self._activation = tf.keras.activations.get(activation)
self._use_bias = use_bias
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._num_summed_dimensions = num_summed_dimensions
self._einsum_string = None
def _build_einsum_string(self, free_input_dims, bound_dims, output_dims):
input_str = ""
kernel_str = ""
output_str = ""
letter_offset = 0
for i in range(free_input_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_input_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
return input_str + "," + kernel_str + "->" + output_str
def build(self, input_shape):
input_shape = tf.TensorShape(input_shape)
input_rank = input_shape.rank
free_input_dims = input_rank - self._num_summed_dimensions
output_dims = len(self._output_shape)
self._einsum_string = self._build_einsum_string(free_input_dims,
self._num_summed_dimensions,
output_dims)
# This is only saved for testing purposes.
self._kernel_shape = (
input_shape[free_input_dims:].concatenate(self._output_shape))
self._kernel = self.add_weight(
"kernel",
shape=self._kernel_shape,
initializer=self._kernel_initializer,
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
trainable=True)
if self._use_bias:
self._bias = self.add_weight(
"bias",
shape=self._output_shape,
initializer=self._bias_initializer,
regularizer=self._bias_regularizer,
constraint=self._bias_constraint,
dtype=self.dtype,
trainable=True)
else:
self._bias = None
super(DenseEinsum, self).build(input_shape)
def get_config(self):
config = {
"output_shape":
self._output_shape,
"num_summed_dimensions":
self._num_summed_dimensions,
"activation":
tf.keras.activations.serialize(self._activation),
"use_bias":
self._use_bias,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(DenseEinsum, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
ret = tf.einsum(self._einsum_string, inputs, self._kernel)
if self._use_bias:
ret += self._bias
if self._activation is not None:
ret = self._activation(ret)
return ret
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras-based einsum layer."""
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import dense_einsum
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class DenseEinsumLayer(keras_parameterized.TestCase):
def test_3D_einsum_with_two_bound_dimensions(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=(64,), num_summed_dimensions=2)
# Create a 4-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 40, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
def test_3D_einsum_with_one_bound_dimensions(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=(64, 32), num_summed_dimensions=1)
# Create a 3-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abc,cde->abde")
self.assertEqual(test_layer._kernel_shape, (80, 64, 32))
def test_2D_einsum_with_one_bound_dimensions(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=(64,), num_summed_dimensions=1)
# Create a 3-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
self.assertEqual(test_layer._kernel_shape, (80, 64))
def test_bias_term_can_be_disabled(self):
# A layer created using the bias should have two weights.
test_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1, use_bias=True)
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(2, len(test_layer.get_weights()))
# A layer created without the bias should have only one weight.
test_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1, use_bias=False)
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(1, len(test_layer.get_weights()))
def test_activation(self):
# Create a model that does not use an activation.
no_activation_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1, activation=None)
input_tensor = tf.keras.Input(shape=(None, 80))
output_tensor = no_activation_layer(input_tensor)
no_activation_model = tf.keras.Model(input_tensor, output_tensor)
# Create a model that uses a softmax activation.
activation_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1, activation="softmax")
input_tensor = tf.keras.Input(shape=(None, 80))
output_tensor = activation_layer(input_tensor)
activation_model = tf.keras.Model(input_tensor, output_tensor)
# Make sure the models' weights are identical.
activation_model.set_weights(no_activation_model.get_weights())
# Predict using each model on the same input data. The output should be
# different, since one is using a softmax - even though the models' weights
# are the same.
input_values = 10 * np.random.random_sample((10, 4, 80))
non_activated_data = no_activation_model.predict(input_values)
activated_data = activation_model.predict(input_values)
self.assertNotAllClose(activated_data, non_activated_data)
def test_non_iterable_output_shape(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1)
# Create a 3-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
self.assertEqual(test_layer._kernel_shape, (80, 64))
def test_with_explicit_initializer(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=(64,),
num_summed_dimensions=2,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 4-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 40, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
if __name__ == "__main__":
tf.test.main()
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.nhnet.multi_channel_attention."""
"""Tests for projects.nhnet.multi_channel_attention."""
import numpy as np
import tensorflow as tf
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
import collections
import math
import string
import numpy as np
import tensorflow as tf
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
`(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
`bs` and `<non-attention dims>` are treated as `<batch dims>`.
The attention operations can be generalized:
(1) Query-key dot product:
`(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
<key attention dims>, num_heads, channels) -> (<batch dims>,
num_heads, <query attention dims>, <key attention dims>)`
(2) Combination:
`(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
<query attention dims>, num_heads, channels)`
Args:
rank: Rank of query, key, value tensors.
attn_axes: List/tuple of axes, `[-1, rank)`,
that attention will be applied to.
Returns:
Einsum equations.
"""
target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
letter_offset += 1
product_notation = "".join([target_notation[i] for i in batch_dims] +
[target_notation[i] for i in attn_axes] +
[source_notation[i] for i in attn_axes])
dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
product_notation)
attn_scores_rank = len(product_notation)
combine_equation = "%s,%s->%s" % (product_notation, source_notation,
target_notation)
return dot_product_equation, combine_equation, attn_scores_rank
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
return equation, bias_axes, len(output_str)
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
class ReuseMultiHeadAttention(tf.keras.layers.Layer):
"""MultiHeadAttention layer.
This is an implementation of multi-headed attention as described in the paper
"Attention is all you Need" (Vaswani et al., 2017).
If `query`, `key,` `value` are the same, then
this is self-attention. Each timestep in `query` attends to the
corresponding sequence in `key`, and returns a fixed-width vector.
This layer first projects `query`, `key` and `value`. These are
(effectively) a list of tensors of length `num_attention_heads`, where the
corresponding shapes are `(batch_size, <query dimensions>, key_dim)`,
`(batch_size, <key/value dimensions>, key_dim)`,
`(batch_size, <key/value dimensions>, value_dim)`.
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor.
Finally, the result tensor with the last dimension as value_dim can take an
linear projection and return.
Examples:
Performs 1D cross-attention over two sequence inputs with an attention mask.
Returns the additional attention weights over heads.
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
... return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)
Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)
Args:
num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of each attention head for value.
dropout: Dropout probability.
reuse_attention: An integer specifying number of heads to reuse.
-1 for all heads.
use_relative_pe: Whether to use relative position bias.
max_sequence_length: Used to set the size of the relative positin encodings.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
Call arguments:
query: Query `Tensor` of shape `(B, T, dim)`.
value: Value `Tensor` of shape `(B, S, dim)`.
key: Optional key `Tensor` of shape `(B, S, dim)`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions. The boolean mask specifies which query
elements can attend to which key elements, 1 indicates attention and 0
indicates no attention. Broadcasting can happen for the missing batch
dimensions and the head dimension.
return_attention_scores: A boolean to indicate whether the output should
be attention output if True, or (attention_output, attention_scores) if
False. Defaults to False.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (no dropout).
Defaults to either using the training mode of the parent layer/model,
or False (inference) if there is no parent layer.
Returns:
attention_output: The result of the computation, of shape `(B, T, E)`,
where `T` is for target sequence shapes and `E` is the query input last
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
are project to the shape specified by `output_shape`.
attention_scores: [Optional] multi-head attention coeffients over
attention axes.
"""
def __init__(self,
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
reuse_attention=0,
use_relative_pe=False,
pe_max_seq_length=512,
use_bias=True,
output_shape=None,
attention_axes=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(ReuseMultiHeadAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._key_dim = key_dim
self._value_dim = value_dim if value_dim else key_dim
self._dropout = dropout
if reuse_attention > self._num_heads or reuse_attention < -1:
raise ValueError("reuse_attention should be between -1 "
"and %d in call to %s." % (self.__class__,
self._num_heads))
if reuse_attention == -1:
reuse_attention = self._num_heads
self._reuse_heads = reuse_attention
self._use_relative_pe = use_relative_pe
self._pe_max_seq_length = pe_max_seq_length
self._use_bias = use_bias
self._output_shape = output_shape
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
if attention_axes is not None and not isinstance(attention_axes,
collections.abc.Sized):
self._attention_axes = (attention_axes,)
else:
self._attention_axes = attention_axes
self._built_from_signature = False
self._query_shape, self._key_shape, self._value_shape = None, None, None
# Use relative PE only if reuse_heads < num_heads.
if self._use_relative_pe and self._reuse_heads < self._num_heads:
# Determine the dtype from global policy.
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
policy = tf.bfloat16
elif policy.name == "mixed_float16":
policy = tf.float16
else:
policy = tf.float32
self._position_embeddings = tf.Variable(
name="relative_position_embeddings",
initial_value=lambda: tf.random.truncated_normal( # pylint: disable=g-long-lambda
[
1, self._num_heads - self._reuse_heads, 2 * self.
_pe_max_seq_length - 1
], mean=0.0, stddev=0.2, dtype=policy),
trainable=True, dtype=policy)
def get_config(self):
config = {
"num_heads": self._num_heads,
"key_dim": self._key_dim,
"value_dim": self._value_dim,
"dropout": self._dropout,
"use_bias": self._use_bias,
"output_shape": self._output_shape,
"attention_axes": self._attention_axes,
"reuse_attention": self._reuse_heads,
"use_relative_pe": self._use_relative_pe,
"pe_max_seq_length": self._pe_max_seq_length,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"query_shape": self._query_shape,
"key_shape": self._key_shape,
"value_shape": self._value_shape,
}
base_config = super(ReuseMultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
# If the layer has a different build() function from the Keras default,
# we need to trigger the customized build to create weights.
query_shape = config.pop("query_shape")
key_shape = config.pop("key_shape")
value_shape = config.pop("value_shape")
layer = cls(**config)
if None in [query_shape, key_shape, value_shape]:
tf.get_logger().warning(
"One of dimensions of the input shape is missing. It should have been"
" memorized when the layer was serialized. "
"%s is created without weights.",
str(cls))
else:
layer._build_from_signature(query_shape, value_shape, key_shape) # pylint: disable=protected-access
return layer
def _build_from_signature(self, query, value, key=None):
"""Builds layers and variables.
Once the method is called, self._built_from_signature will be set to True.
Args:
query: Query tensor or TensorShape.
value: Value tensor or TensorShape.
key: Key tensor or TensorShape.
"""
self._built_from_signature = True
if hasattr(query, "shape"):
self._query_shape = tf.TensorShape(query.shape)
else:
self._query_shape = tf.TensorShape(query)
if hasattr(value, "shape"):
self._value_shape = tf.TensorShape(value.shape)
else:
self._value_shape = tf.TensorShape(value)
if key is None:
self._key_shape = self._value_shape
elif hasattr(key, "shape"):
self._key_shape = tf.TensorShape(key.shape)
else:
self._key_shape = tf.TensorShape(key)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
# Any setup work performed only once should happen in an `init_scope`
# to avoid creating symbolic Tensors that will later pollute any eager
# operations.
with tf.init_scope():
free_dims = self._query_shape.rank - 1
if self._reuse_heads < self._num_heads:
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, [
self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, [
self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = []
if self._reuse_heads > 0:
self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_reuse",
**common_kwargs))
if self._reuse_heads < self._num_heads:
self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, [
self._num_heads - self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_new",
**common_kwargs))
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once
# it support mult-head einsum computations.
self._build_attention(output_rank)
self._output_dense = []
if self._reuse_heads > 0:
self._output_dense.append(self._make_output_dense(
free_dims, common_kwargs, "attention_output_reuse"))
if self._reuse_heads < self._num_heads:
self._output_dense.append(self._make_output_dense(
free_dims, common_kwargs, "attention_output_new",
self._reuse_heads == 0))
def _make_output_dense(self, free_dims, common_kwargs, name=None,
use_bias=True):
"""Builds the output projection matrix.
Args:
free_dims: Number of free dimensions for einsum equation building.
common_kwargs: Common keyword arguments for einsum layer.
name: Name for the projection layer.
use_bias: Use bias if self._use_bias is true
Returns:
Projection layer.
"""
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else:
output_shape = [self._query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
return tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if (use_bias and self._use_bias) else None,
name=name,
**common_kwargs)
def _build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
costomize attention computation to replace the default dot-product
attention.
Args:
rank: the rank of query, key, value tensors.
"""
if self._attention_axes is None:
self._attention_axes = tuple(range(1, rank - 2))
else:
self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._softmax = tf.keras.layers.Softmax(axis=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def _masked_softmax(self, attention_scores, attention_mask=None):
# Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S]
if attention_mask is not None:
# The expand dim happens starting from the `num_heads` dimension,
# (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
mask_expansion_axes = [-len(self._attention_axes) * 2 - 1]
for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
attention_mask = tf.expand_dims(
attention_mask, axis=mask_expansion_axes)
return self._softmax(attention_scores, attention_mask)
def _compute_relative_position(self, query_seq_length, key_seq_length):
position_zero = self._pe_max_seq_length - 1
# We take the vector position variable and concatenate to form a matrix of
# relative position encodings. i=0 indicates reltaive position is 0.
indices = tf.expand_dims(tf.range(0, -query_seq_length, -1),
-1) + tf.range(key_seq_length) + position_zero
indices = tf.maximum(indices, 0)
indices = tf.minimum(indices, 2*self._pe_max_seq_length-2)
attention_biases = tf.gather(self._position_embeddings, indices, axis=2)
return attention_biases
def _compute_attention(self,
query,
key,
value,
reuse_scores=None,
attention_mask=None,
training=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
multi-head Q, K, V inputs. Users can override this function for customized
attention implementation.
Args:
query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
key: Projected key `Tensor` of shape `(B, T, N, key_dim)`.
value: Projected value `Tensor` of shape `(B, T, N, value_dim)`.
reuse_scores: Attention scores from a previous layer if needed.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
Returns:
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Partial or no reuse
if self._reuse_heads < self._num_heads:
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
new_scores = tf.einsum(self._dot_product_equation, key, query)
# Add relative position embeddings if required.
if self._use_relative_pe:
new_scores = new_scores + self._compute_relative_position(
tf.shape(query)[1], tf.shape(key)[1])
new_scores = self._masked_softmax(new_scores, attention_mask)
if self._reuse_heads > 0: # Partial reuse
reuse_scores = reuse_scores[:, :self._reuse_heads, :, :]
attention_scores = tf.concat([new_scores, reuse_scores], 1)
else: # No reuse
attention_scores = new_scores
else: # Full reuse
attention_scores = reuse_scores
new_scores = None
# `context_layer` = [B, T, N, H]
attention_output = []
# Partial or full reuse
if self._reuse_heads > 0:
attention_output.append(
tf.einsum(self._combine_equation, self._dropout_layer(
reuse_scores, training=training), value[0]))
# Partial or no reuse
if self._reuse_heads < self._num_heads:
attention_output.append(
tf.einsum(self._combine_equation, self._dropout_layer(
new_scores, training=training), value[-1]))
return attention_output, attention_scores
def call(self,
query,
value,
key=None,
attention_mask=None,
return_attention_scores=False,
training=None,
reuse_attention_scores=None):
if self._reuse_heads > 0 and reuse_attention_scores is None:
raise ValueError("reuse_attention_scores cannot be None when "
"reuse_attention is True or > 0.")
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `value` = [B, S, N, H]
value = [vd(value) for vd in self._value_dense]
if self._reuse_heads < self._num_heads:
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S, N, H]
key = self._key_dense(key)
else:
query, key = None, None
attention_output, attention_scores = self._compute_attention(
query, key, value, reuse_attention_scores, attention_mask, training)
attention_output = [od(attention_output[i]) for i, od in enumerate(
self._output_dense)]
if len(attention_output) == 1:
attention_output = attention_output[0]
else:
attention_output = attention_output[0] + attention_output[1]
if return_attention_scores:
return attention_output, attention_scores
return attention_output
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the attention layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import reuse_attention as attention
class ReuseMultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("key_value_same_proj", None, None, [40, 80]),
("key_value_different_proj", 32, 60, [40, 60]),
)
def test_non_masked_attention(self, value_dim, output_shape, output_dims):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=12,
key_dim=64,
value_dim=value_dim,
output_shape=output_shape)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
"""Test attention outputs with coefficients."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer(query, query, return_attention_scores=True)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
def test_attention_scores_with_values(self):
"""Test attention outputs with coefficients."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(60, 80))
output, coef = test_layer(query, value, return_attention_scores=True)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 60])
@parameterized.named_parameters(
("with_bias", True, 0), ("no_bias", False, 0),
("reuse_all_with_bias", True, -1), ("reuse_all_no_bias", False, -1),
("reuse_partial_with_bias", True, 1),
("reuse_partial_no_bias", False, 1))
def test_masked_attention(self, use_bias, reuse_attention):
"""Test with a mask tensor."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=2, key_dim=2, use_bias=use_bias,
reuse_attention=reuse_attention)
# Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
reuse_attention_scores = tf.keras.Input(shape=(2, 4, 2))
output = test_layer(query=query, value=value, attention_mask=mask_tensor,
reuse_attention_scores=reuse_attention_scores)
# Create a model containing the test layer.
model = tf.keras.Model(
[query, value, mask_tensor, reuse_attention_scores], output)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
reuse_scores = np.random.random_sample((batch_size, 2, 4, 2))
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
masked_output_data = model.predict(
[from_data, to_data, mask_data, reuse_scores])
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones((batch_size, 4, 2))
unmasked_output_data = model.predict(
[from_data, to_data, null_mask_data, reuse_scores])
# Because one data is masked and one is not, the outputs should not be the
# same.
if reuse_attention == -1:
self.assertAllEqual(masked_output_data, unmasked_output_data)
else:
self.assertNotAllClose(masked_output_data, unmasked_output_data)
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer(query, value=value, key=key, attention_mask=mask_tensor,
reuse_attention_scores=reuse_attention_scores)
model = tf.keras.Model(
[query, value, key, mask_tensor, reuse_attention_scores], output)
masked_output_data = model.predict(
[from_data, to_data, to_data, mask_data, reuse_scores])
unmasked_output_data = model.predict(
[from_data, to_data, to_data, null_mask_data, reuse_scores])
# Because one data is masked and one is not, the outputs should not be the
# same.
if reuse_attention == -1:
self.assertAllEqual(masked_output_data, unmasked_output_data)
else:
self.assertNotAllClose(masked_output_data, unmasked_output_data)
if reuse_attention > 0:
self.assertLen(test_layer._output_dense, 2)
if use_bias:
if reuse_attention == 0:
self.assertLen(test_layer._query_dense.trainable_variables, 2)
self.assertLen(test_layer._output_dense[0].trainable_variables, 2)
if len(test_layer._output_dense) == 2:
self.assertLen(test_layer._output_dense[1].trainable_variables, 1)
else:
if reuse_attention == 0:
self.assertLen(test_layer._query_dense.trainable_variables, 1)
self.assertLen(test_layer._output_dense[0].trainable_variables, 1)
if len(test_layer._output_dense) == 2:
self.assertLen(test_layer._output_dense[1].trainable_variables, 1)
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=12,
key_dim=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_masked_attention_with_scores(self):
"""Test with a mask tensor."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=2, key_dim=2)
# Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
masked_output_data = model.predict([from_data, to_data, mask_data])
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones((batch_size, 4, 2))
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
# Create a model containing attention scores.
output, scores = test_layer(
query=query, value=value, attention_mask=mask_tensor,
return_attention_scores=True)
model = tf.keras.Model([query, value, mask_tensor], [output, scores])
masked_output_data_score, masked_score = model.predict(
[from_data, to_data, mask_data])
unmasked_output_data_score, unmasked_score = model.predict(
[from_data, to_data, null_mask_data])
self.assertNotAllClose(masked_output_data_score, unmasked_output_data_score)
self.assertAllClose(masked_output_data, masked_output_data_score)
self.assertAllClose(unmasked_output_data, unmasked_output_data_score)
self.assertNotAllClose(masked_score, unmasked_score)
@parameterized.named_parameters(
("4d_inputs_1freebatch_mask2", [3, 4], [3, 2], [4, 2],
(2,)), ("4d_inputs_1freebatch_mask3", [3, 4], [3, 2], [3, 4, 2], (2,)),
("4d_inputs_1freebatch_mask4", [3, 4], [3, 2], [3, 2, 4, 2],
(2,)), ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)),
("5D_inputs_2D_attention_fullmask", [5, 3, 4], [5, 3, 2], [5, 3, 4, 3, 2],
(2, 3)))
def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
"""Test with a mask tensor."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=2, key_dim=2, attention_axes=attention_axes)
batch_size, hidden_size = 3, 8
# Generate data for the input (non-mask) tensors.
query_shape = [batch_size] + q_dims + [hidden_size]
value_shape = [batch_size] + v_dims + [hidden_size]
mask_shape = [batch_size] + mask_dims
query = 10 * np.random.random_sample(query_shape)
value = 10 * np.random.random_sample(value_shape)
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
# Because one data is masked and one is not, the outputs should not be the
# same.
query_tensor = tf.keras.Input(query_shape[1:], name="query")
value_tensor = tf.keras.Input(value_shape[1:], name="value")
mask_tensor = tf.keras.Input(mask_shape[1:], name="mask")
output = test_layer(query=query_tensor, value=value_tensor,
attention_mask=mask_tensor)
model = tf.keras.Model([query_tensor, value_tensor, mask_tensor], output)
self.assertNotAllClose(
model.predict([query, value, mask_data]),
model.predict([query, value, null_mask_data]))
def test_dropout(self):
test_layer = attention.ReuseMultiHeadAttention(
num_heads=2, key_dim=2, dropout=0.5)
# Generate data for the input (non-mask) tensors.
from_data = tf.keras.backend.ones(shape=(32, 4, 8))
to_data = tf.keras.backend.ones(shape=(32, 2, 8))
train_out = test_layer(from_data, to_data, None, None, None, True)
test_out = test_layer(from_data, to_data, None, None, None, False)
# Output should be close when not in training mode,
# and should not be close when enabling dropout in training mode.
self.assertNotAllClose(
tf.keras.backend.eval(train_out),
tf.keras.backend.eval(test_out))
def test_non_masked_self_attention_with_reuse(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=12, key_dim=64, reuse_attention=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
reuse_scores = tf.keras.Input(shape=(12, 40, 40))
output = test_layer(query, query, reuse_attention_scores=reuse_scores)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
("no_reuse_with_pe_max_seq_length_20", False, 20),
("reuse_all_with_pe_max_seq_length_20", True, 20),
("reuse_partial_with_pe_max_seq_length_20", 5, 20),
("no_reuse_with_pe_max_seq_length_40", False, 40),
("reuse_all_with_pe_max_seq_length_40", True, 40),
("reuse_partial_with_pe_max_seq_length_40", 5, 40))
def test_non_masked_self_attention_with_relative_pe(self, reuse_attention,
pe_max_seq_length):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = attention.ReuseMultiHeadAttention(
num_heads=12, key_dim=64, reuse_attention=reuse_attention,
use_relative_pe=True, pe_max_seq_length=pe_max_seq_length)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
reuse_scores = tf.keras.Input(shape=(12, 40, 40))
output = test_layer(query, query, reuse_attention_scores=reuse_scores)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
query = tf.keras.Input(shape=(30, 80))
reuse_scores = tf.keras.Input(shape=(12, 30, 30))
output = test_layer(query, query, reuse_attention_scores=reuse_scores)
self.assertEqual(output.shape.as_list(), [None, 30, 80])
query = tf.keras.Input(shape=(30, 80))
key = tf.keras.Input(shape=(20, 80))
reuse_scores = tf.keras.Input(shape=(12, 30, 20))
output = test_layer(query, key, reuse_attention_scores=reuse_scores)
self.assertEqual(output.shape.as_list(), [None, 30, 80])
query = tf.keras.Input(shape=(50, 80))
key = tf.keras.Input(shape=(60, 80))
reuse_scores = tf.keras.Input(shape=(12, 50, 60))
output = test_layer(query, key, reuse_attention_scores=reuse_scores)
self.assertEqual(output.shape.as_list(), [None, 50, 80])
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf
from official.nlp.modeling.layers import reuse_attention as attention
class ReuseTransformer(tf.keras.layers.Layer):
"""Transformer layer.
This layer implements the ReuseTransformer Encoder from
"Leveraging redundancy in attention with Reuse Transformers".
(https://arxiv.org/abs/2110.06821)
"""
def __init__(self,
num_attention_heads,
inner_dim,
inner_activation,
head_size=None,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
output_dropout=0.0,
attention_dropout=0.0,
inner_dropout=0.0,
attention_initializer=None,
attention_axes=None,
reuse_attention=0,
use_relative_pe=False,
pe_max_seq_length=512,
layer_idx=None,
max_reuse_layer_idx=None,
**kwargs):
"""Initializes `ReuseTransformer`.
Args:
num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
head_size: Projection size of heads.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: Dropout probability for within the attention layer.
inner_dropout: Dropout probability for the first Dense layer in a
two-layer feedforward network.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
reuse_attention: reuse_attention: An integer specifying number of heads
to reuse. -1 for all heads.
use_relative_pe: whether to use relative position bias.
pe_max_seq_length: used to set the size of the relative positin encodings.
layer_idx: the idx of this layer.
max_reuse_layer_idx: layer idx (if passed) greater than this value will
not reuse attention scores from previous layers.
**kwargs: keyword arguments.
"""
super().__init__(**kwargs)
self._num_heads = num_attention_heads
self._inner_dim = inner_dim
self._inner_activation = inner_activation
self._head_size = head_size
self._attention_dropout = attention_dropout
self._attention_dropout_rate = attention_dropout
self._output_dropout = output_dropout
self._output_dropout_rate = output_dropout
self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._inner_dropout = inner_dropout
self._reuse_attention = reuse_attention
self._use_relative_pe = use_relative_pe
self._pe_max_seq_length = pe_max_seq_length
self._layer_idx = layer_idx
self._max_reuse_layer_idx = max_reuse_layer_idx
# Overwrite for the first layer and layers greater than max_reuse_layer_idx.
if self._layer_idx is not None and (
self._layer_idx == 0 or (self._max_reuse_layer_idx is not None and
self._max_reuse_layer_idx < self._layer_idx)):
self._reuse_attention = 0
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_axes = attention_axes
def build(self, input_shape):
if isinstance(input_shape, tf.TensorShape):
input_tensor_shape = input_shape
elif isinstance(input_shape, (list, tuple)):
input_tensor_shape = tf.TensorShape(input_shape[0])
else:
raise ValueError(
"The type of input shape argument is not supported, got: %s" %
type(input_shape))
einsum_equation = "abc,cd->abd"
if len(input_tensor_shape.as_list()) > 3:
einsum_equation = "...bc,cd->...bd"
hidden_size = input_tensor_shape[-1]
if self._head_size is None:
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
else:
self._attention_head_size = self._head_size
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._attention_layer = attention.ReuseMultiHeadAttention(
num_heads=self._num_heads,
key_dim=self._attention_head_size,
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
attention_axes=self._attention_axes,
reuse_attention=self._reuse_attention,
use_relative_pe=self._use_relative_pe,
pe_max_seq_length=self._pe_max_seq_length,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(
rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
self._output_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
super(ReuseTransformer, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"inner_dim":
self._inner_dim,
"inner_activation":
self._inner_activation,
"head_size":
self._head_size,
"output_dropout":
self._output_dropout_rate,
"attention_dropout":
self._attention_dropout_rate,
"output_range":
self._output_range,
"reuse_attention":
self._reuse_attention,
"use_relative_pe": self._use_relative_pe,
"pe_max_seq_length": self._pe_max_seq_length,
"max_reuse_layer_idx": self._max_reuse_layer_idx,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"inner_dropout":
self._inner_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer),
"attention_axes": self._attention_axes,
}
base_config = super(ReuseTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
"""Transformer self-attention encoder block call.
Args:
inputs: a single tensor or a list of tensors.
`input tensor` as the single sequence of embeddings.
[`input tensor`, `attention mask`] to have the additional attention
mask.
[`query tensor`, `attention mask`, `attention scores`] to have
additional attention scores for reuse computation. If `attention scores`
is None, the reuse_attention flag will be ignored.
Returns:
An output tensor with the same dimensions as input/query tensor.
Attention scores if return_attention_scores is true.
"""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
reuse_attention_scores = None
elif len(inputs) == 3:
input_tensor, attention_mask, reuse_attention_scores = inputs
else:
raise ValueError("Unexpected inputs to %s with length at %d" %
(self.__class__, len(inputs)))
else:
input_tensor, attention_mask, reuse_attention_scores = (inputs, None,
None)
key_value = None
if self._reuse_attention != 0 and reuse_attention_scores is None:
raise ValueError(
"reuse_attention_scores cannot be None when reuse_attention != 0.")
if self._output_range:
if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
if reuse_attention_scores is not None:
reuse_attention_scores = reuse_attention_scores[:, :,
0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask,
reuse_attention_scores=reuse_attention_scores,
return_attention_scores=True)
attention_output, attention_scores = attention_output
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
return source_attention_output + layer_output, attention_scores
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output, attention_scores
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras-based transformer block layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import reuse_transformer
@parameterized.named_parameters(
('base', reuse_transformer.ReuseTransformer))
class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
def tearDown(self):
super(ReuseTransformerLayerTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy('float32')
def test_layer_creation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor, _ = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor, _ = test_layer([data_tensor, mask_tensor])
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_invocation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# Create a model from the test layer.
model = tf.keras.Model(data_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
_ = model.predict(input_data)
def test_layer_invocation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_layer_output_range(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor, _ = test_layer([input_data, mask_data])
# The layer only attends to the first token and outputs the first token
# embedding.
new_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
output_range=1)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_output_range_with_relative_pe(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu',
use_relative_pe=True)
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor, _ = test_layer([input_data, mask_data])
# The layer only attends to the first token and outputs the first token
# embedding.
new_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
output_range=1,
use_relative_pe=True)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_output_range_without_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048,
inner_activation='relu', norm_first=True)
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
output_tensor, _ = test_layer(input_data)
# The layer only attends to the first token and outputs the first token
# embedding.
new_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
output_range=1,
norm_first=True)
_ = new_layer(input_data)
new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer(input_data)
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_output_range_with_pre_norm(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048,
inner_activation='relu', norm_first=True)
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor, _ = test_layer([input_data, mask_data])
# The layer only attends to the first token and outputs the first token
# embedding.
new_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
output_range=1,
norm_first=True)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_transform_with_initializer(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output, _ = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
def test_dynamic_layer_sequence(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor, _ = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_length = 17
input_data = np.ones((1, input_length, width))
output_data = model.predict(input_data)
self.assertAllEqual([1, input_length, width], output_data.shape)
class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
def test_use_bias_norm_first(self):
num_attention_heads = 2
hidden_size = 16
encoder_block = reuse_transformer.ReuseTransformer(
num_attention_heads=num_attention_heads,
inner_dim=32,
inner_activation='relu',
output_dropout=0.1,
attention_dropout=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
inner_dropout=0.1,
attention_initializer=tf.keras.initializers.RandomUniform(
minval=0., maxval=1.))
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_mask]
output, _ = encoder_block(inputs)
self.assertEqual(output.shape, (2, 4, hidden_size))
def test_get_config(self):
num_attention_heads = 2
encoder_block = reuse_transformer.ReuseTransformer(
num_attention_heads=num_attention_heads,
inner_dim=32,
inner_activation='relu',
output_dropout=0.1,
attention_dropout=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
inner_dropout=0.1,
attention_initializer=tf.keras.initializers.RandomUniform(
minval=0., maxval=1.))
encoder_block_config = encoder_block.get_config()
new_encoder_block = reuse_transformer.ReuseTransformer.from_config(
encoder_block_config)
self.assertEqual(encoder_block_config, new_encoder_block.get_config())
@parameterized.parameters({'attention_axes': None}, {'attention_axes': [1]},
{'attention_axes': [2]}, {'attention_axes': [1, 2]})
def test_several_attention_axes(self, attention_axes):
test_layer = reuse_transformer.ReuseTransformer(
inner_dim=32,
inner_activation='relu',
output_dropout=0.1,
attention_dropout=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
inner_dropout=0.1,
num_attention_heads=10,
attention_axes=attention_axes)
num_rows = 21
num_cols = 13
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(num_rows, num_cols, width))
output_tensor, _ = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
@parameterized.named_parameters(
('plain', False, False, False),
('plain_returnscore', False, True, False),
('plain_with_relative_pe', False, False, True),
('reuse_all', True, False, False),
('reuse_all_returnscore', True, True, False),
('reuse_all_with_relative_pe', True, False, True),
('reuse_5', 5, False, False),
('reuse_5_returnscore', 5, True, False),
('reuse_5_with_relative_pe', 5, False, True),)
def test_layer_invocation_with_mask(self, reuse_attention,
return_attention_scores, use_relative_pe):
test_layer = reuse_transformer.ReuseTransformer(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
reuse_attention=reuse_attention,
use_relative_pe=use_relative_pe)
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
return_scores_tensor = tf.keras.Input(shape=(1,))
reuse_attention_scores = tf.keras.Input(
shape=(10, sequence_length, sequence_length))
output_tensor, _ = test_layer(
[data_tensor, mask_tensor, reuse_attention_scores])
# Create a model from the test layer.
model = tf.keras.Model(
([data_tensor, mask_tensor, reuse_attention_scores],
return_scores_tensor), output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
reuse_scores = np.random.rand(
batch_size, 10, sequence_length, sequence_length)
_ = model.predict([input_data, mask_data, reuse_scores],
return_attention_scores)
@parameterized.named_parameters(
('without_relative_pe_with_pe_max_seq_length_10', False, 10),
('with_relative_pe_with_pe_max_seq_length_10', True, 10),
('without_relative_pe_with_pe_max_seq_length_100', False, 100),
('with_relative_pe_with_pe_max_seq_length_100', True, 100))
def test_layer_invocation_with_float16_with_relative_pe(
self, use_relative_pe, pe_max_seq_length):
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test_layer = reuse_transformer.ReuseTransformer(
num_attention_heads=10, inner_dim=2048, inner_activation='relu',
use_relative_pe=use_relative_pe, pe_max_seq_length=pe_max_seq_length)
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
if __name__ == '__main__':
tf.test.main()
......@@ -26,7 +26,6 @@ from official.nlp.modeling.ops import beam_search
EOS_ID = 1
@tf.keras.utils.register_keras_serializable(package="Text")
class Seq2SeqTransformer(tf.keras.Model):
"""Transformer model with Keras.
......@@ -261,11 +260,9 @@ class Seq2SeqTransformer(tf.keras.Model):
return {"outputs": top_decoded_ids, "scores": top_scores}
decoder_inputs = self.embedding_lookup(targets)
embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
decoder_inputs = self.embedding_lookup(targets)
length = tf.shape(decoder_inputs)[1]
pos_encoding = self.position_embedding(decoder_inputs)
pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
......@@ -326,12 +323,7 @@ class Seq2SeqTransformer(tf.keras.Model):
decoder_input = ids[:, -1:]
# Preprocess decoder input by getting embeddings and adding timing signal.
# decoder_input = self.embedding_softmax_layer(decoder_input)
source_decoder_input = decoder_input
decoder_input = self.embedding_lookup(decoder_input)
embedding_mask = tf.cast(
tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
decoder_input *= tf.expand_dims(embedding_mask, -1)
decoder_input += timing_signal[i]
if self._padded_decode:
# indexing does not work on TPU.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer-based BERT encoder network with dense features as inputs."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Callable, Optional, Union
from absl import logging
import tensorflow as tf
from official.nlp.modeling import layers
_Initializer = Union[str, tf.keras.initializers.Initializer]
_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
class BertDenseEncoder(tf.keras.layers.Layer):
"""Bi-directional Transformer-based encoder network with dense features.
This network is the same as the BertEncoder except it also concats dense
features with the embeddings.
Args:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
"""
def __init__(
self,
vocab_size: int,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: Callable[..., Any] = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
**kwargs):
# Pops kwargs that are used in V1 implementation.
if 'dict_outputs' in kwargs:
kwargs.pop('dict_outputs')
if 'return_all_encoder_outputs' in kwargs:
kwargs.pop('return_all_encoder_outputs')
if 'intermediate_size' in kwargs:
inner_dim = kwargs.pop('intermediate_size')
if 'activation' in kwargs:
inner_activation = kwargs.pop('activation')
if 'dropout_rate' in kwargs:
output_dropout = kwargs.pop('dropout_rate')
if 'attention_dropout_rate' in kwargs:
attention_dropout = kwargs.pop('attention_dropout_rate')
super().__init__(**kwargs)
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
self._transformer_layers = []
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
for i in range(num_layers):
layer = layers.TransformerEncoderBlock(
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
}
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
dense_inputs=tf.keras.Input(
shape=(None, embedding_width), dtype=tf.float32),
dense_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
dense_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
)
def call(self, inputs):
word_embeddings = None
if isinstance(inputs, dict):
word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids')
word_embeddings = inputs.get('input_word_embeddings', None)
dense_inputs = inputs.get('dense_inputs')
dense_mask = inputs.get('dense_mask')
dense_type_ids = inputs.get('dense_type_ids')
else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
if word_embeddings is None:
word_embeddings = self._embedding_layer(word_ids)
# Concat the dense embeddings at sequence end.
combined_embeddings = tf.concat([word_embeddings, dense_inputs], axis=1)
combined_type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
combined_mask = tf.concat([mask, dense_mask], axis=1)
# absolute position embeddings.
position_embeddings = self._position_embedding_layer(combined_embeddings)
type_embeddings = self._type_embedding_layer(combined_type_ids)
embeddings = combined_embeddings + position_embeddings + type_embeddings
embeddings = self._embedding_norm_layer(embeddings)
embeddings = self._embedding_dropout(embeddings)
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
attention_mask = self._attention_mask_layer(embeddings, combined_mask)
encoder_outputs = []
x = embeddings
for layer in self._transformer_layers:
x = layer([x, attention_mask])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
return dict(
sequence_output=encoder_outputs[-1],
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return dict(self._config)
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you contine to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transformer-based bert encoder network with dense features as inputs."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import bert_dense_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class BertDenseEncoderTest(keras_parameterized.TestCase):
def tearDown(self):
super(BertDenseEncoderTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
def test_dict_outputs_network_creation(self):
hidden_size = 32
sequence_length = 21
dense_sequence_length = 20
# Create a small dense BertDenseEncoder for testing.
kwargs = {}
test_network = bert_dense_encoder.BertDenseEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
**kwargs)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dense_inputs = tf.keras.Input(
shape=(dense_sequence_length, hidden_size), dtype=tf.float32)
dense_mask = tf.keras.Input(shape=(dense_sequence_length,), dtype=tf.int32)
dense_type_ids = tf.keras.Input(
shape=(dense_sequence_length,), dtype=tf.int32)
dict_outputs = test_network(
dict(
input_word_ids=word_ids,
input_mask=mask,
input_type_ids=type_ids,
dense_inputs=dense_inputs,
dense_mask=dense_mask,
dense_type_ids=dense_type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
self.assertIsInstance(test_network.transformer_layers, list)
self.assertLen(test_network.transformer_layers, 3)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
expected_data_shape = [
None, sequence_length + dense_sequence_length, hidden_size
]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_dict_outputs_all_encoder_outputs_network_creation(self):
hidden_size = 32
sequence_length = 21
dense_sequence_length = 20
# Create a small BertEncoder for testing.
test_network = bert_dense_encoder.BertDenseEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
dict_outputs=True)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dense_inputs = tf.keras.Input(
shape=(dense_sequence_length, hidden_size), dtype=tf.float32)
dense_mask = tf.keras.Input(shape=(dense_sequence_length,), dtype=tf.int32)
dense_type_ids = tf.keras.Input(
shape=(dense_sequence_length,), dtype=tf.int32)
dict_outputs = test_network(
dict(
input_word_ids=word_ids,
input_mask=mask,
input_type_ids=type_ids,
dense_inputs=dense_inputs,
dense_mask=dense_mask,
dense_type_ids=dense_type_ids))
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [
None, sequence_length + dense_sequence_length, hidden_size
]
expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, 3)
for data in all_encoder_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_dict_outputs_network_creation_with_float16_dtype(self):
hidden_size = 32
sequence_length = 21
dense_sequence_length = 20
tf.keras.mixed_precision.set_global_policy("mixed_float16")
# Create a small BertEncoder for testing.
test_network = bert_dense_encoder.BertDenseEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
dict_outputs=True)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dense_inputs = tf.keras.Input(
shape=(dense_sequence_length, hidden_size), dtype=tf.float32)
dense_mask = tf.keras.Input(shape=(dense_sequence_length,), dtype=tf.int32)
dense_type_ids = tf.keras.Input(
shape=(dense_sequence_length,), dtype=tf.int32)
dict_outputs = test_network(
dict(
input_word_ids=word_ids,
input_mask=mask,
input_type_ids=type_ids,
dense_inputs=dense_inputs,
dense_mask=dense_mask,
dense_type_ids=dense_type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [
None, sequence_length + dense_sequence_length, hidden_size
]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# If float_dtype is set to float16, the data output is float32 (from a layer
# norm) and pool output should be float16.
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(tf.float16, pooled.dtype)
@parameterized.named_parameters(
("all_sequence_encoder_v2", bert_dense_encoder.BertDenseEncoder, None,
41),
("output_range_encoder_v2", bert_dense_encoder.BertDenseEncoder, 1, 1),
)
def test_dict_outputs_network_invocation(
self, encoder_cls, output_range, out_seq_len):
hidden_size = 32
sequence_length = 21
dense_sequence_length = 20
vocab_size = 57
num_types = 7
# Create a small BertEncoder for testing.
test_network = encoder_cls(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
output_range=output_range,
dict_outputs=True)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dense_inputs = tf.keras.Input(
shape=(dense_sequence_length, hidden_size), dtype=tf.float32)
dense_mask = tf.keras.Input(shape=(dense_sequence_length,), dtype=tf.int32)
dense_type_ids = tf.keras.Input(
shape=(dense_sequence_length,), dtype=tf.int32)
dict_outputs = test_network(
dict(
input_word_ids=word_ids,
input_mask=mask,
input_type_ids=type_ids,
dense_inputs=dense_inputs,
dense_mask=dense_mask,
dense_type_ids=dense_type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
# Create a model based off of this network:
model = tf.keras.Model(
[word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids],
[data, pooled])
# Invoke the model. We can't validate the output data here (the model is too
# complex) but this will catch structural runtime errors.
batch_size = 3
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
dense_input_data = np.random.rand(batch_size, dense_sequence_length,
hidden_size)
dense_mask_data = np.random.randint(
2, size=(batch_size, dense_sequence_length))
dense_type_ids_data = np.random.randint(
num_types, size=(batch_size, dense_sequence_length))
outputs = model.predict([
word_id_data, mask_data, type_id_data, dense_input_data,
dense_mask_data, dense_type_ids_data
])
self.assertEqual(outputs[0].shape[1], out_seq_len)
# Creates a BertEncoder with max_sequence_length != sequence_length
max_sequence_length = 128
test_network = encoder_cls(
vocab_size=vocab_size,
hidden_size=hidden_size,
max_sequence_length=max_sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
dict_outputs=True)
dict_outputs = test_network(
dict(
input_word_ids=word_ids,
input_mask=mask,
input_type_ids=type_ids,
dense_inputs=dense_inputs,
dense_mask=dense_mask,
dense_type_ids=dense_type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
model = tf.keras.Model(
[word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids],
[data, pooled])
outputs = model.predict([
word_id_data, mask_data, type_id_data, dense_input_data,
dense_mask_data, dense_type_ids_data
])
self.assertEqual(outputs[0].shape[1],
sequence_length + dense_sequence_length)
# Creates a BertEncoder with embedding_width != hidden_size
embedding_width = 16
test_network = bert_dense_encoder.BertDenseEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
max_sequence_length=max_sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
embedding_width=embedding_width,
dict_outputs=True)
dense_inputs = tf.keras.Input(
shape=(dense_sequence_length, embedding_width), dtype=tf.float32)
dense_input_data = np.zeros(
(batch_size, dense_sequence_length, embedding_width), dtype=float)
dict_outputs = test_network(
dict(
input_word_ids=word_ids,
input_mask=mask,
input_type_ids=type_ids,
dense_inputs=dense_inputs,
dense_mask=dense_mask,
dense_type_ids=dense_type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
model = tf.keras.Model(
[word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids],
[data, pooled])
outputs = model.predict([
word_id_data, mask_data, type_id_data, dense_input_data,
dense_mask_data, dense_type_ids_data
])
self.assertEqual(outputs[0].shape[-1], hidden_size)
self.assertTrue(hasattr(test_network, "_embedding_projection"))
def test_embeddings_as_inputs(self):
hidden_size = 32
sequence_length = 21
dense_sequence_length = 20
# Create a small BertEncoder for testing.
test_network = bert_dense_encoder.BertDenseEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dense_inputs = tf.keras.Input(
shape=(dense_sequence_length, hidden_size), dtype=tf.float32)
dense_mask = tf.keras.Input(shape=(dense_sequence_length,), dtype=tf.int32)
dense_type_ids = tf.keras.Input(
shape=(dense_sequence_length,), dtype=tf.int32)
test_network.build(
dict(
input_word_ids=word_ids,
input_mask=mask,
input_type_ids=type_ids,
dense_inputs=dense_inputs,
dense_mask=dense_mask,
dense_type_ids=dense_type_ids))
embeddings = test_network.get_embedding_layer()(word_ids)
# Calls with the embeddings.
dict_outputs = test_network(
dict(
input_word_embeddings=embeddings,
input_mask=mask,
input_type_ids=type_ids,
dense_inputs=dense_inputs,
dense_mask=dense_mask,
dense_type_ids=dense_type_ids))
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [
None, sequence_length + dense_sequence_length, hidden_size
]
expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, 3)
for data in all_encoder_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
if __name__ == "__main__":
tf.test.main()
......@@ -102,6 +102,9 @@ class EncoderScaffold(tf.keras.Model):
dict_outputs: Whether to use a dictionary as the model outputs.
layer_idx_as_attention_seed: Whether to include layer_idx in
attention_cfg in hidden_cfg.
feed_layer_idx: whether the scaffold should feed layer index to hidden_cls.
recursive: whether to pass the second return of the hidden layer as the last
element among the inputs. None will be passed as the initial state.
"""
def __init__(self,
......@@ -120,6 +123,8 @@ class EncoderScaffold(tf.keras.Model):
return_all_layer_outputs=False,
dict_outputs=False,
layer_idx_as_attention_seed=False,
feed_layer_idx=False,
recursive=False,
**kwargs):
if embedding_cls:
......@@ -201,6 +206,8 @@ class EncoderScaffold(tf.keras.Model):
'contain classes or instances with size specified by '
'num_hidden_instances, got %d vs %d.') % self.name, len(hidden_cls),
num_hidden_instances)
# Consider supporting customized init states.
recursive_states = None
for i in range(num_hidden_instances):
if isinstance(hidden_cls, list):
cur_hidden_cls = hidden_cls[i]
......@@ -211,10 +218,15 @@ class EncoderScaffold(tf.keras.Model):
layer_idx_as_attention_seed):
hidden_cfg = copy.deepcopy(hidden_cfg)
hidden_cfg['attention_cfg']['seed'] = i
if feed_layer_idx:
hidden_cfg['layer_idx'] = i
layer = cur_hidden_cls(**hidden_cfg)
else:
layer = cur_hidden_cls
data = layer([data, attention_mask])
if recursive:
data, recursive_states = layer([data, attention_mask, recursive_states])
else:
data = layer([data, attention_mask])
layer_output_data.append(data)
hidden_layers.append(layer)
......
......@@ -69,6 +69,10 @@ class BigBirdEncoder(tf.keras.Model):
embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
intermediate_size: The intermediate size for the transformer layers.
block_size: int. A BigBird Attention parameter: size of block in from/to
sequences.
num_rand_blocks: int. A BigBird Attention parameter: number of random chunks
per row.
activation: The activation to use for the transformer layers.
dropout_rate: The dropout rate to use for the transformer layers.
attention_dropout_rate: The dropout rate to use for the attention layers
......
......@@ -12,3 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Models NLP Tasks."""
# pylint: disable=g-multiple-import
from official.nlp.tasks.electra_task import ElectraPretrainConfig, ElectraPretrainTask
from official.nlp.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
from official.nlp.tasks.question_answering import QuestionAnsweringConfig, QuestionAnsweringTask
from official.nlp.tasks.sentence_prediction import SentencePredictionConfig, SentencePredictionTask
from official.nlp.tasks.tagging import TaggingConfig, TaggingTask
from official.nlp.tasks.translation import TranslationConfig, TranslationTask
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dual encoder (retrieval) task."""
from typing import Mapping, Tuple
# Import libraries
from absl import logging
import dataclasses
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
"""A dual encoder (retrieval) configuration."""
# Normalize input embeddings if set to True.
normalize: bool = True
# Maximum input sequence length.
max_sequence_length: int = 64
# Parameters for training a dual encoder model with additive margin, see
# https://www.ijcai.org/Proceedings/2019/0746.pdf for more details.
logit_scale: float = 1
logit_margin: float = 0
bidirectional: bool = False
# Defining k for calculating metrics recall@k.
eval_top_k: Tuple[int, ...] = (1, 3, 10)
encoder: encoders.EncoderConfig = (
encoders.EncoderConfig())
@dataclasses.dataclass
class DualEncoderConfig(cfg.TaskConfig):
"""The model config."""
# At most one of `init_checkpoint` and `hub_module_url` can
# be specified.
init_checkpoint: str = ''
hub_module_url: str = ''
# Defines the concrete model config at instantiation time.
model: ModelConfig = ModelConfig()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
@task_factory.register_task_cls(DualEncoderConfig)
class DualEncoderTask(base_task.Task):
"""Task object for dual encoder."""
def build_model(self):
"""Interface to build model. Refer to base_task.Task.build_model."""
if self.task_config.hub_module_url and self.task_config.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
encoder_network = utils.get_encoder_from_hub(
self.task_config.hub_module_url)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
# Currently, we only supports bert-style dual encoder.
return models.DualEncoder(
network=encoder_network,
max_seq_length=self.task_config.model.max_sequence_length,
normalize=self.task_config.model.normalize,
logit_scale=self.task_config.model.logit_scale,
logit_margin=self.task_config.model.logit_margin,
output='logits')
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Interface to compute losses. Refer to base_task.Task.build_losses."""
del labels
left_logits = model_outputs['left_logits']
right_logits = model_outputs['right_logits']
batch_size = tf_utils.get_shape_list(left_logits, name='batch_size')[0]
ranking_labels = tf.range(batch_size)
loss = tf_utils.safe_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=ranking_labels,
logits=left_logits))
if self.task_config.model.bidirectional:
right_rank_loss = tf_utils.safe_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=ranking_labels,
logits=right_logits))
loss += right_rank_loss
return tf.reduce_mean(loss)
def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
"""Returns tf.data.Dataset for sentence_prediction task."""
if params.input_path != 'dummy':
return data_loader_factory.get_data_loader(params).load(input_context)
def dummy_data(_):
dummy_ids = tf.zeros((10, params.seq_length), dtype=tf.int32)
x = dict(
left_word_ids=dummy_ids,
left_mask=dummy_ids,
left_type_ids=dummy_ids,
right_word_ids=dummy_ids,
right_mask=dummy_ids,
right_type_ids=dummy_ids)
return x
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
def build_metrics(self, training=None):
del training
metrics = [tf.keras.metrics.Mean(name='batch_size_per_core')]
for k in self.task_config.model.eval_top_k:
metrics.append(tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name=f'left_recall_at_{k}'))
if self.task_config.model.bidirectional:
metrics.append(tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name=f'right_recall_at_{k}'))
return metrics
def process_metrics(self, metrics, labels, model_outputs):
del labels
metrics = dict([(metric.name, metric) for metric in metrics])
left_logits = model_outputs['left_logits']
right_logits = model_outputs['right_logits']
batch_size = tf_utils.get_shape_list(
left_logits, name='sequence_output_tensor')[0]
ranking_labels = tf.range(batch_size)
for k in self.task_config.model.eval_top_k:
metrics[f'left_recall_at_{k}'].update_state(ranking_labels, left_logits)
if self.task_config.model.bidirectional:
metrics[f'right_recall_at_{k}'].update_state(ranking_labels,
right_logits)
metrics['batch_size_per_core'].update_state(batch_size)
def validation_step(self,
inputs,
model: tf.keras.Model,
metrics=None) -> Mapping[str, tf.Tensor]:
outputs = model(inputs)
loss = self.build_losses(
labels=None, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, None, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, None, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
pretrain2finetune_mapping = {
'encoder': model.checkpoint_items['encoder'],
}
ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.tasks.sentence_prediction."""
import functools
import os
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.bert import configs
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import dual_encoder_dataloader
from official.nlp.tasks import dual_encoder
from official.nlp.tasks import masked_lm
from official.nlp.tools import export_tfhub_lib
class DualEncoderTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(DualEncoderTaskTest, self).setUp()
self._train_data_config = (
dual_encoder_dataloader.DualEncoderDataConfig(
input_path="dummy", seq_length=32))
def get_model_config(self):
return dual_encoder.ModelConfig(
max_sequence_length=32,
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
def _run_task(self, config):
task = dual_encoder.DualEncoderTask(config)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = strategy.distribute_datasets_from_function(
functools.partial(task.build_inputs, config.train_data))
dataset.batch(10)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
def test_task(self):
config = dual_encoder.DualEncoderConfig(
init_checkpoint=self.get_temp_dir(),
model=self.get_model_config(),
train_data=self._train_data_config)
task = dual_encoder.DualEncoderTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint.
pretrain_cfg = bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
ckpt.save(config.init_checkpoint)
task.initialize(model)
def _export_bert_tfhub(self):
bert_config = configs.BertConfig(
vocab_size=30522,
hidden_size=16,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=4)
encoder = export_tfhub_lib.get_bert_encoder(bert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(encoder=encoder)
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
with tf.io.gfile.GFile(vocab_file, "w") as f:
f.write("dummy content")
export_path = os.path.join(self.get_temp_dir(), "hub")
export_tfhub_lib.export_model(
export_path,
bert_config=bert_config,
encoder_config=None,
model_checkpoint_path=model_checkpoint_path,
vocab_file=vocab_file,
do_lower_case=True,
with_mlm=False)
return export_path
def test_task_with_hub(self):
hub_module_url = self._export_bert_tfhub()
config = dual_encoder.DualEncoderConfig(
hub_module_url=hub_module_url,
model=self.get_model_config(),
train_data=self._train_data_config)
self._run_task(config)
if __name__ == "__main__":
tf.test.main()
......@@ -16,7 +16,6 @@
import math
import tensorflow as tf
from official.nlp.modeling import layers
class Attention(tf.keras.layers.Layer):
......@@ -51,28 +50,31 @@ class Attention(tf.keras.layers.Layer):
attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
self.hidden_size)
self.query_dense_layer = layers.DenseEinsum(
output_shape=(self.num_heads, size_per_head),
self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
use_bias=False,
bias_axes=None,
name="query")
self.key_dense_layer = layers.DenseEinsum(
output_shape=(self.num_heads, size_per_head),
self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
use_bias=False,
bias_axes=None,
name="key")
self.value_dense_layer = layers.DenseEinsum(
output_shape=(self.num_heads, size_per_head),
self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
use_bias=False,
bias_axes=None,
name="value")
output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
self.output_dense_layer = layers.DenseEinsum(
output_shape=self.hidden_size,
num_summed_dimensions=2,
self.output_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTNH,NHE->BTE",
output_shape=(None, self.hidden_size),
kernel_initializer=output_initializer,
use_bias=False,
bias_axes=None,
name="output_transform")
super(Attention, self).build(input_shape)
......
......@@ -24,6 +24,7 @@ import unicodedata
from absl import app
from absl import flags
from absl import logging
import six
from six.moves import range
import tensorflow as tf
......@@ -109,11 +110,11 @@ def bleu_on_list(ref_lines, hyp_lines, case_sensitive=False):
def main(unused_argv):
if FLAGS.bleu_variant in ("both", "uncased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
tf.logging.info("Case-insensitive results: %f" % score)
logging.info("Case-insensitive results: %f", score)
if FLAGS.bleu_variant in ("both", "cased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
tf.logging.info("Case-sensitive results: %f" % score)
logging.info("Case-sensitive results: %f", score)
def define_compute_bleu_flags():
......@@ -142,7 +143,6 @@ def define_compute_bleu_flags():
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_compute_bleu_flags()
FLAGS = flags.FLAGS
app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment