Unverified Commit f16a7b5b authored by vedanshu, committed by GitHub

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,53 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras layer that creates a self-attention mask."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""Keras layer that creates a self-attention mask."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.keras_nlp import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class SelfAttentionMask(tf.keras.layers.Layer):
"""Create 3D attention mask from a 2D tensor mask.
class SelfAttentionMask(layers.SelfAttentionMask):
"""Creates 3D attention mask from a 2D tensor mask.
**Warning: Please use `keras_nlp.layers.SelfAttentionMask` instead.**
inputs[0]: from_tensor: 2D or 3D Tensor of shape
[batch_size, from_seq_length, ...].
inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
`(batch_size, from_seq_length, ...)`.
inputs[1]: to_mask: int32 Tensor of shape `(batch_size, to_seq_length)`.
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
Float Tensor of shape `(batch_size, from_seq_length, to_seq_length)`.
"""
def call(self, inputs):
from_tensor = inputs[0]
to_mask = inputs[1]
from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=from_tensor.dtype)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding
# tokens), so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
if isinstance(inputs, list):
return super().call(inputs[0], inputs[1])
else:
return super().call(inputs)
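# Usage sketch (illustrative shapes, not part of the original file):
#
#   embeddings = tf.random.uniform([2, 8, 16])          # [batch, from_seq, width]
#   padding_mask = tf.constant([[1] * 8,
#                               [1] * 5 + [0] * 3])      # [batch, to_seq]
#   mask_3d = SelfAttentionMask()([embeddings, padding_mask])
#   # mask_3d has shape (2, 8, 8): 1.0 where attention is allowed, 0.0 at padding.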
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers.
## References:
[1] Yuichi Yoshida, Takeru Miyato. Spectral Norm Regularization for Improving
the Generalizability of Deep Learning.
_arXiv preprint arXiv:1705.10941_, 2017. https://arxiv.org/abs/1705.10941
[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida.
Spectral normalization for generative adversarial networks.
In _International Conference on Learning Representations_, 2018.
[3] Henry Gouk, Eibe Frank, Bernhard Pfahringer, Michael Cree.
Regularisation of neural networks by enforcing Lipschitz continuity.
_arXiv preprint arXiv:1804.04368_, 2018. https://arxiv.org/abs/1804.04368
"""
import numpy as np
import tensorflow as tf
class SpectralNormalization(tf.keras.layers.Wrapper):
"""Implements spectral normalization for Dense layer."""
def __init__(self,
layer,
iteration=1,
norm_multiplier=0.95,
training=True,
aggregation=tf.VariableAggregation.MEAN,
inhere_layer_name=False,
**kwargs):
"""Initializer.
Args:
layer: (tf.keras.layers.Layer) A TF Keras layer to apply normalization to.
iteration: (int) The number of power iteration to perform to estimate
weight matrix's singular value.
norm_multiplier: (float) Multiplicative constant to threshold the
normalization. Usually under normalization, the singular value will
converge to this value.
training: (bool) Whether to perform power iteration to update the singular
value estimate.
aggregation: (tf.VariableAggregation) Indicates how a distributed variable
will be aggregated. Accepted values are constants defined in the class
tf.VariableAggregation.
inhere_layer_name: (bool) Whether to inherit the name of the input layer.
**kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
"""
self.iteration = iteration
self.do_power_iteration = training
self.aggregation = aggregation
self.norm_multiplier = norm_multiplier
# Set layer name.
wrapper_name = kwargs.pop('name', None)
if inhere_layer_name:
wrapper_name = layer.name
if not isinstance(layer, tf.keras.layers.Layer):
raise ValueError('`layer` must be a `tf.keras.layers.Layer`. '
'Observed `{}`'.format(layer))
super(SpectralNormalization, self).__init__(
layer, name=wrapper_name, **kwargs)
def build(self, input_shape):
super(SpectralNormalization, self).build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list()
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight(
shape=(1, np.prod(self.w_shape[:-1])),
initializer=self.uv_initializer,
trainable=False,
name='v',
dtype=self.dtype,
aggregation=self.aggregation)
self.u = self.add_weight(
shape=(1, self.w_shape[-1]),
initializer=self.uv_initializer,
trainable=False,
name='u',
dtype=self.dtype,
aggregation=self.aggregation)
self.update_weights()
def call(self, inputs, *, training=None):
training = self.do_power_iteration if training is None else training
u_update_op, v_update_op, w_update_op = self.update_weights(
training=training)
output = self.layer(inputs)
w_restore_op = self.restore_weights()
# Register update ops.
self.add_update(u_update_op)
self.add_update(v_update_op)
self.add_update(w_update_op)
self.add_update(w_restore_op)
return output
def update_weights(self, *, training=True):
w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])
u_hat = self.u
v_hat = self.v
if training:
for _ in range(self.iteration):
v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w_reshaped)))
u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w_reshaped))
sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
# Convert sigma from a 1x1 matrix to a scalar.
sigma = tf.reshape(sigma, [])
u_update_op = self.u.assign(u_hat)
v_update_op = self.v.assign(v_hat)
# Bound spectral norm to be not larger than self.norm_multiplier.
w_norm = tf.cond((self.norm_multiplier / sigma) < 1, lambda: # pylint:disable=g-long-lambda
(self.norm_multiplier / sigma) * self.w, lambda: self.w)
w_update_op = self.layer.kernel.assign(w_norm)
return u_update_op, v_update_op, w_update_op
def restore_weights(self):
"""Restores layer weights to maintain gradient update (See Alg 1 of [1])."""
return self.layer.kernel.assign(self.w)
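# For intuition, a minimal sketch of the power iteration performed in
# `update_weights` above, written for a plain 2D weight matrix w of shape
# [in_dim, out_dim] (illustrative pseudo-NumPy, not part of the original file):
#
#   u = random_normal(shape=(1, out_dim))
#   for _ in range(iteration):
#     v = l2_normalize(u @ w.T)            # shape (1, in_dim)
#     u = l2_normalize(v @ w)              # shape (1, out_dim)
#   sigma = v @ w @ u.T                    # estimate of the largest singular value
#   if norm_multiplier / sigma < 1:
#     w = (norm_multiplier / sigma) * w    # rescale so the spectral norm is bounded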
class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
"""Implements spectral normalization for Conv2D layer based on [3]."""
def __init__(self,
layer,
iteration=1,
norm_multiplier=0.95,
training=True,
aggregation=tf.VariableAggregation.MEAN,
legacy_mode=False,
**kwargs):
"""Initializer.
Args:
layer: (tf.keras.layers.Layer) A TF Keras layer to apply normalization to.
iteration: (int) The number of power iteration to perform to estimate
weight matrix's singular value.
norm_multiplier: (float) Multiplicative constant to threshold the
normalization. Usually under normalization, the singular value will
converge to this value.
training: (bool) Whether to perform power iteration to update the singular
value estimate.
aggregation: (tf.VariableAggregation) Indicates how a distributed variable
will be aggregated. Accepted values are constants defined in the class
tf.VariableAggregation.
legacy_mode: (bool) Whether to use the legacy implementation where the
dimension of the u and v vectors are set to the batch size. It should
not be enabled unless for backward compatibility reasons.
**kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
"""
self.iteration = iteration
self.do_power_iteration = training
self.aggregation = aggregation
self.norm_multiplier = norm_multiplier
self.legacy_mode = legacy_mode
# Set layer attributes.
layer._name += '_spec_norm'
if not isinstance(layer, tf.keras.layers.Conv2D):
raise ValueError(
'layer must be a `tf.keras.layers.Conv2D` instance. You passed: {input}'
.format(input=layer))
super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs)
def build(self, input_shape):
self.layer.build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
# Shape (kernel_size_1, kernel_size_2, in_channel, out_channel).
self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list()
self.strides = self.layer.strides
# Set the dimensions of u and v vectors.
batch_size = input_shape[0]
uv_dim = batch_size if self.legacy_mode else 1
# Resolve shapes.
in_height = input_shape[1]
in_width = input_shape[2]
in_channel = self.w_shape[2]
out_height = in_height // self.strides[0]
out_width = in_width // self.strides[1]
out_channel = self.w_shape[3]
self.in_shape = (uv_dim, in_height, in_width, in_channel)
self.out_shape = (uv_dim, out_height, out_width, out_channel)
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight(
shape=self.in_shape,
initializer=self.uv_initializer,
trainable=False,
name='v',
dtype=self.dtype,
aggregation=self.aggregation)
self.u = self.add_weight(
shape=self.out_shape,
initializer=self.uv_initializer,
trainable=False,
name='u',
dtype=self.dtype,
aggregation=self.aggregation)
super(SpectralNormalizationConv2D, self).build()
def call(self, inputs):
u_update_op, v_update_op, w_update_op = self.update_weights()
output = self.layer(inputs)
w_restore_op = self.restore_weights()
# Register update ops.
self.add_update(u_update_op)
self.add_update(v_update_op)
self.add_update(w_update_op)
self.add_update(w_restore_op)
return output
def update_weights(self):
"""Computes power iteration for convolutional filters based on [3]."""
# Initialize u, v vectors.
u_hat = self.u
v_hat = self.v
if self.do_power_iteration:
for _ in range(self.iteration):
# Updates v.
v_ = tf.nn.conv2d_transpose(
u_hat,
self.w,
output_shape=self.in_shape,
strides=self.strides,
padding='SAME')
v_hat = tf.nn.l2_normalize(tf.reshape(v_, [1, -1]))
v_hat = tf.reshape(v_hat, v_.shape)
# Updates u.
u_ = tf.nn.conv2d(v_hat, self.w, strides=self.strides, padding='SAME')
u_hat = tf.nn.l2_normalize(tf.reshape(u_, [1, -1]))
u_hat = tf.reshape(u_hat, u_.shape)
v_w_hat = tf.nn.conv2d(v_hat, self.w, strides=self.strides, padding='SAME')
sigma = tf.matmul(tf.reshape(v_w_hat, [1, -1]), tf.reshape(u_hat, [-1, 1]))
# Convert sigma from a 1x1 matrix to a scalar.
sigma = tf.reshape(sigma, [])
u_update_op = self.u.assign(u_hat)
v_update_op = self.v.assign(v_hat)
w_norm = tf.cond((self.norm_multiplier / sigma) < 1, lambda: # pylint:disable=g-long-lambda
(self.norm_multiplier / sigma) * self.w, lambda: self.w)
w_update_op = self.layer.kernel.assign(w_norm)
return u_update_op, v_update_op, w_update_op
def restore_weights(self):
"""Restores layer weights to maintain gradient update (See Alg 1 of [1])."""
return self.layer.kernel.assign(self.w)
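# Note: in the convolutional case above, `tf.nn.conv2d(v, w)` applies the linear
# operator defined by the filter and `tf.nn.conv2d_transpose(u, w)` applies its
# adjoint, so the u/v updates in `update_weights` perform power iteration on the
# full convolution operator (cf. [3]) rather than on a flattened kernel matrix.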
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for normalization layers.
## References:
[1] Hanie Sedghi, Vineet Gupta, Philip M. Long.
The Singular Values of Convolutional Layers.
In _International Conference on Learning Representations_, 2019.
"""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import spectral_normalization
DenseLayer = tf.keras.layers.Dense(10)
Conv2DLayer = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='valid')
def _compute_spectral_norm(weight):
if weight.ndim > 2:
# Computes Conv2D via FFT transform as in [1].
weight = np.fft.fft2(weight, weight.shape[1:3], axes=[0, 1])
return np.max(np.linalg.svd(weight, compute_uv=False))
class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(NormalizationTest, self).setUp()
self.num_iterations = 1000
self.norm_multiplier = 0.95
@parameterized.named_parameters(
('Dense',
(None, 10), DenseLayer, spectral_normalization.SpectralNormalization),
('Conv2D', (None, 32, 32, 3), Conv2DLayer,
spectral_normalization.SpectralNormalizationConv2D))
def test_spec_norm_magnitude(self, input_shape, layer, norm_wrapper):
"""Tests if the weights spectral norm converges to norm_multiplier."""
layer.build(input_shape)
sn_layer = norm_wrapper(
layer,
iteration=self.num_iterations,
norm_multiplier=self.norm_multiplier)
# Perform normalization.
sn_layer.build(input_shape)
sn_layer.update_weights()
normalized_kernel = sn_layer.layer.kernel.numpy()
spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
spectral_norm_expected = self.norm_multiplier
self.assertAllClose(
spectral_norm_computed, spectral_norm_expected, atol=5e-2)
# Test that the normalized layer is K-Lipschitz. In particular, if the layer
# is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
# is the norm multiplier.
new_input_shape = (16,) + input_shape[1:]
new_input = tf.random.uniform(new_input_shape)
delta_vec = tf.random.uniform(new_input_shape)
output1 = sn_layer(new_input)
output2 = sn_layer(new_input + delta_vec)
delta_input = tf.norm(tf.reshape(delta_vec, (-1,))).numpy()
delta_output = tf.norm(tf.reshape(output2 - output1, (-1,))).numpy()
self.assertLessEqual(delta_output, self.norm_multiplier * delta_input)
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Talking Head Attention layer."""
# pylint: disable=g-classes-have-attributes
import math
@@ -20,14 +20,12 @@ import string
import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
_CHR_IDX = string.ascii_lowercase
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TalkingHeadsAttention(attention.MultiHeadAttention):
class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
"""Implements Talking-Heads Attention.
This is an implementation of Talking-Heads Attention based on the paper
@@ -35,12 +33,12 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
multi-head attention by including linear projections across the attention-heads
dimension, immediately before and after the softmax operation.
See the base class `MultiHeadAttention` for more details.
See the base class `tf.keras.layers.MultiHeadAttention` for more details.
Arguments:
Args:
num_heads: Number of attention heads.
key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
key_dim: Size of each attention head for query and key.
value_dim: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
@@ -65,7 +63,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
that will be applied on attention scores before and after softmax.
Args:
qkv_rank: the rank of query, key, value tensors after projection.
qkv_rank: The rank of query, key, value tensors after projection.
"""
super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
@@ -107,18 +105,21 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
attention_mask=None,
training=None):
"""Applies Dot-product attention with query, key, value tensors.
This function overrides base class to apply additional linear projection
on attention scores before and after softmax.
Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_dim]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
Returns:
attention_output: Multi-headed outputs of attention computation.
@@ -129,7 +130,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
1.0 / math.sqrt(float(self._key_dim)))
# Apply linear projection before softmax
attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
@@ -145,7 +146,8 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_scores_dropout = self._dropout_layer(attention_scores)
attention_scores_dropout = self._dropout_layer(
attention_scores, training=training)
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation,
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the attention layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for the attention layer."""
from absl.testing import parameterized
import numpy as np
@@ -36,35 +32,36 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
("key_value_same_proj", None, None, [40, 80]),
("key_value_different_proj", 32, 60, [40, 60]),
)
def test_non_masked_attention(self, value_size, output_shape, output_dims):
def test_non_masked_attention(self, value_dim, output_shape, output_dims):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12,
key_size=64,
value_size=value_size,
key_dim=64,
value_dim=value_dim,
output_shape=output_shape)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=64)
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
"""Test attention outputs with coefficients."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=64, return_attention_scores=True)
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query])
output, coef = test_layer(query=query, value=query,
return_attention_scores=True)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
@@ -72,13 +69,13 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
def test_masked_attention(self, use_bias):
"""Test with a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=2, use_bias=use_bias)
num_heads=12, key_dim=2, use_bias=use_bias)
# Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor)
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
@@ -102,7 +99,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
output = test_layer(
query=query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
@@ -123,11 +121,11 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
"""Test with a specified initializer."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12,
key_size=64,
key_dim=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
@@ -137,7 +135,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
"""Test with a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=2, attention_axes=attention_axes)
num_heads=12, key_dim=2, attention_axes=attention_axes)
batch_size, hidden_size = 3, 8
# Generate data for the input (non-mask) tensors.
query_shape = [batch_size] + q_dims + [hidden_size]
@@ -149,11 +147,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data)
unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras Layers for BERT-specific preprocessing."""
from typing import Any, Dict, List, Optional, Union
from absl import logging
import tensorflow as tf
try:
import tensorflow_text as text # pylint: disable=g-import-not-at-top
except ImportError:
text = None
except tf.errors.NotFoundError as e:
logging.warn("Encountered error when importing tensorflow_text: %s", e)
text = None
def _check_if_tf_text_installed():
if text is None:
raise ImportError("import tensorflow_text failed, please install "
"'tensorflow-text-nightly'.")
def _iterative_vectorized_fair_share(capacity: tf.Tensor,
limit: Union[int, tf.Tensor]):
"""Iterative algorithm for max min fairness algorithm.
Reference: https://en.wikipedia.org/wiki/Max-min_fairness
The idea: for each example with some number of segments and a limit on the
total segment length allowed, we grant each segment a fair share of the
limit. If every segment has the same length, there is nothing to do.
If one segment has below-average length, its unused share is split fairly
among the others. In this way, the longest segment ends up as short as
possible among all potential capacity assignments.
Args:
capacity: A rank-2 Tensor of #Segments x Batch.
limit: The largest permissible number of tokens in total across one example.
Returns:
A rank-2 Tensor with new segment capacity assignment such that
the total number of tokens in each example does not exceed the `limit`.
"""
# Firstly, we calculate the lower bound of the capacity assignment.
per_seg_limit = limit // capacity.shape[0]
limit_mask = tf.ones(capacity.shape, dtype=tf.int64) * per_seg_limit
lower_bound = tf.minimum(capacity, limit_mask)
# This step grants full capacity to examples whose total demand already satisfies the limit.
remaining_cap_sum = limit - tf.math.reduce_sum(lower_bound, axis=0)
remaining_cap_mat = capacity - lower_bound
new_cap = lower_bound + remaining_cap_mat * tf.cast(
tf.math.reduce_sum(remaining_cap_mat, axis=0) <= remaining_cap_sum,
tf.int64)
# Process iteratively. This step is O(#segments), see analysis below.
while True:
remaining_limit = limit - tf.math.reduce_sum(new_cap, axis=0)
remaining_cap = capacity - new_cap
masked_remaining_slots = tf.cast(remaining_cap > 0, tf.int64)
remaining_cap_col_slots = tf.reduce_sum(masked_remaining_slots, axis=0)
masked_remaining_limit = tf.cast(remaining_cap_col_slots > 0,
tf.int64) * remaining_limit
# Total remaining segment limit is different for each example.
per_seg_limit = masked_remaining_limit // (
tf.cast(remaining_cap_col_slots <= 0, tf.int64) +
remaining_cap_col_slots) # +1 to make sure 0/0 = 0
# Note that on each step, at least one more segment is fulfilled, or the
# loop finishes.
# The idea is: if the remaining per-example limit exceeds the smallest
# remaining segment need, that smallest segment's request is fulfilled.
# Otherwise, all remaining segments are truncated and the assignment is done.
if tf.math.reduce_sum(per_seg_limit) > 0:
remaining_slots_mat = tf.cast(remaining_cap > 0, tf.int64)
new_cap = new_cap + remaining_slots_mat * per_seg_limit
else:
# Leftover assignment of limit that is smaller than #slots.
new_remained_assignment_mask = tf.cast(
(tf.cumsum(masked_remaining_slots, axis=0) <= masked_remaining_limit)
& (masked_remaining_slots > 0), tf.int64)
new_cap = new_cap + new_remained_assignment_mask
break
return new_cap
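# Worked example (illustrative numbers, not from the original file): two
# segments, three examples, per-example limit of 6 tokens. Rows are segments,
# columns are examples.
#
#   capacity   = [[1, 4, 6],
#                 [5, 4, 2]]
#   fair share = [[1, 3, 4],
#                 [5, 3, 2]]
#
# Example 0: segment 0 needs only 1 token, so segment 1 keeps its full 5.
# Example 1: both segments exceed the fair share of 3, so each gets 3.
# Example 2: segment 1 needs only 2 tokens, so segment 0 gets the remaining 4.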
def round_robin_truncate_inputs(
inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
limit: Union[int, tf.Tensor],
) -> Union[tf.RaggedTensor, List[tf.RaggedTensor]]:
"""Truncates a list of batched segments to fit a per-example length limit.
Available space is assigned one token at a time in a round-robin fashion
to the inputs that still need some, until the limit is reached.
(Or equivalently: the longest input is truncated by one token until the total
length of inputs fits the limit.) Examples that fit the limit as passed in
remain unchanged.
Args:
inputs: A list of rank-2 RaggedTensors. The i-th example is given by
the i-th row in each list element, that is, `inputs[:][i, :]`.
limit: The largest permissible number of tokens in total across one example.
Returns:
A list of rank-2 RaggedTensors at corresponding indices with the inputs,
in which the rows of each RaggedTensor have been truncated such that
the total number of tokens in each example does not exceed the `limit`.
"""
if not isinstance(inputs, (list, tuple)):
return round_robin_truncate_inputs([inputs], limit)[0]
limit = tf.cast(limit, tf.int64)
if not all(rt.shape.rank == 2 for rt in inputs):
raise ValueError("All inputs must have shape [batch_size, (items)]")
if len(inputs) == 1:
return [_truncate_row_lengths(inputs[0], limit)]
elif len(inputs) == 2:
size_a, size_b = [rt.row_lengths() for rt in inputs]
# Here's a brain-twister: This does round-robin assignment of quota
# to both inputs until the limit is reached. Hint: consider separately
# the cases of zero, one, or two inputs exceeding half the limit.
floor_half = limit // 2
ceil_half = limit - floor_half
quota_a = tf.minimum(size_a, ceil_half + tf.nn.relu(floor_half - size_b))
quota_b = tf.minimum(size_b, floor_half + tf.nn.relu(ceil_half - size_a))
return [_truncate_row_lengths(inputs[0], quota_a),
_truncate_row_lengths(inputs[1], quota_b)]
else:
# Note that we don't merge with the 2 input case because the full algorithm
# is more expensive.
capacity = tf.stack([rt.row_lengths() for rt in inputs]) # #Segments x B
new_capacity = _iterative_vectorized_fair_share(capacity, limit)
return [
_truncate_row_lengths(inputs[i], new_capacity[i])
for i in range(capacity.shape[0])
]
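# Worked example for the two-segment quota formula above (illustrative numbers):
# with limit=5, floor_half=2 and ceil_half=3,
#
#   size_a=4, size_b=1: quota_a = min(4, 3 + relu(2 - 1)) = 4
#                       quota_b = min(1, 2 + relu(3 - 4)) = 1    (total 5)
#   size_a=4, size_b=4: quota_a = min(4, 3 + relu(2 - 4)) = 3
#                       quota_b = min(4, 2 + relu(3 - 4)) = 2    (total 5)
#
# Only the inputs that exceed their share of the limit get truncated.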
def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
new_lengths: tf.Tensor) -> tf.RaggedTensor:
"""Truncates the rows of `ragged_tensor` to the given row lengths."""
new_lengths = tf.broadcast_to(new_lengths,
ragged_tensor.bounding_shape()[0:1])
def fn(x):
row, new_length = x
return row[0:new_length]
fn_dtype = tf.RaggedTensorSpec(dtype=ragged_tensor.dtype,
ragged_rank=ragged_tensor.ragged_rank - 1)
result = tf.map_fn(fn, (ragged_tensor, new_lengths), dtype=fn_dtype)
# Work around broken shape propagation: without this, result has unknown rank.
flat_values_shape = [None] * ragged_tensor.flat_values.shape.rank
result = result.with_flat_values(
tf.ensure_shape(result.flat_values, flat_values_shape))
return result
class BertTokenizer(tf.keras.layers.Layer):
"""Wraps BertTokenizer with pre-defined vocab as a Keras Layer.
Attributes:
tokenize_with_offsets: If true, calls
`text.BertTokenizer.tokenize_with_offsets()` instead of plain
`text.BertTokenizer.tokenize()` and outputs a triple of
`(tokens, start_offsets, limit_offsets)`.
raw_table_access: An object with methods `.lookup(keys)` and `.size()`
that operate on the raw lookup table of tokens. It can be used to
look up special token symbols like `[MASK]`.
"""
def __init__(self, *,
vocab_file: str,
lower_case: bool,
tokenize_with_offsets: bool = False,
**kwargs):
"""Initialize a `BertTokenizer` layer.
Args:
vocab_file: A Python string with the path of the vocabulary file.
This is a text file with newline-separated wordpiece tokens.
This layer initializes a lookup table from it that gets used with
`text.BertTokenizer`.
lower_case: A Python boolean forwarded to `text.BertTokenizer`.
If true, input text is converted to lower case (where applicable)
before tokenization. This must be set to match the way in which
the `vocab_file` was created.
tokenize_with_offsets: A Python boolean. If true, this layer calls
`text.BertTokenizer.tokenize_with_offsets()` instead of plain
`text.BertTokenizer.tokenize()` and outputs a triple of
`(tokens, start_offsets, limit_offsets)`
instead of just tokens.
**kwargs: Standard arguments to `Layer()`.
Raises:
ImportError: If importing `tensorflow_text` failed.
"""
_check_if_tf_text_installed()
self.tokenize_with_offsets = tokenize_with_offsets
# TODO(b/177326279): Stop storing the vocab table initializer as an
# attribute when https://github.com/tensorflow/tensorflow/issues/46456
# has been fixed in the TensorFlow versions of the TF Hub users that load
# a SavedModel created from this layer. Due to that issue, loading such a
# SavedModel forgets to add .vocab_table._initializer as a trackable
# dependency of .vocab_table, so that saving it again to a second SavedModel
# (e.g., the final model built using TF Hub) does not properly track
# the ._vocab_table._initializer._filename as an Asset.
self._vocab_table, self._vocab_initializer_donotuse = (
self._create_vocab_table_and_initializer(vocab_file))
self._special_tokens_dict = self._create_special_tokens_dict(
self._vocab_table, vocab_file)
super().__init__(**kwargs)
self._bert_tokenizer = text.BertTokenizer(
self._vocab_table, lower_case=lower_case)
@property
def vocab_size(self):
return self._vocab_table.size()
def _create_vocab_table_and_initializer(self, vocab_file):
vocab_initializer = tf.lookup.TextFileInitializer(
vocab_file,
key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
vocab_table = tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
return vocab_table, vocab_initializer
def call(self, inputs: tf.Tensor):
"""Calls `text.BertTokenizer` on inputs.
Args:
inputs: A string Tensor of shape `(batch_size,)`.
Returns:
One or three of `RaggedTensors` if `tokenize_with_offsets` is False or
True, respectively. These are
tokens: A `RaggedTensor` of shape
`[batch_size, (words), (pieces_per_word)]`
and type int32. `tokens[i,j,k]` contains the k-th wordpiece of the
j-th word in the i-th input.
start_offsets, limit_offsets: If `tokenize_with_offsets` is True,
RaggedTensors of type int64 with the same indices as tokens.
Element `[i,j,k]` contains the byte offset at the start, or past the
end, resp., for the k-th wordpiece of the j-th word in the i-th input.
"""
# Prepare to reshape the result to work around broken shape inference.
batch_size = tf.shape(inputs)[0]
def _reshape(rt):
values = rt.values
row_splits = rt.row_splits
row_splits = tf.reshape(row_splits, [batch_size + 1])
return tf.RaggedTensor.from_row_splits(values, row_splits)
# Call the tokenizer.
if self.tokenize_with_offsets:
tokens, start_offsets, limit_offsets = (
self._bert_tokenizer.tokenize_with_offsets(inputs))
tokens = tf.cast(tokens, dtype=tf.int32)
return _reshape(tokens), _reshape(start_offsets), _reshape(limit_offsets)
else:
tokens = self._bert_tokenizer.tokenize(inputs)
tokens = tf.cast(tokens, dtype=tf.int32)
return _reshape(tokens)
def get_config(self):
# Skip in tf.saved_model.save(); fail if called directly.
raise NotImplementedError("TODO(b/170480226): implement")
def get_special_tokens_dict(self):
"""Returns dict of token ids, keyed by standard names for their purpose.
Returns:
A dict from Python strings to Python integers. Each key is a standard
name for a special token describing its use. (For example, "padding_id"
is what BERT traditionally calls "[PAD]" but others may call "<pad>".)
The corresponding value is the integer token id. If a special token
is not found, its entry is omitted from the dict.
The supported keys and tokens are:
* start_of_sequence_id: looked up from "[CLS]"
* end_of_segment_id: looked up from "[SEP]"
* padding_id: looked up from "[PAD]"
* mask_id: looked up from "[MASK]"
* vocab_size: one past the largest token id used
"""
return self._special_tokens_dict
def _create_special_tokens_dict(self, vocab_table, vocab_file):
special_tokens = dict(start_of_sequence_id="[CLS]",
end_of_segment_id="[SEP]",
padding_id="[PAD]",
mask_id="[MASK]")
with tf.init_scope():
if tf.executing_eagerly():
special_token_ids = vocab_table.lookup(
tf.constant(list(special_tokens.values()), tf.string))
vocab_size = vocab_table.size()
else:
# A blast from the past: non-eager init context while building Model.
# This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
logging.warning(
"Non-eager init context; computing "
"BertTokenizer's special_tokens_dict in tf.compat.v1.Session")
with tf.Graph().as_default():
local_vocab_table, _ = self._create_vocab_table_and_initializer(
vocab_file)
special_token_ids_tensor = local_vocab_table.lookup(
tf.constant(list(special_tokens.values()), tf.string))
vocab_size_tensor = local_vocab_table.size()
init_ops = [tf.compat.v1.initialize_all_tables()]
with tf.compat.v1.Session() as sess:
sess.run(init_ops)
special_token_ids, vocab_size = sess.run(
[special_token_ids_tensor, vocab_size_tensor])
result = dict(
vocab_size=int(vocab_size) # Numpy to Python.
)
for k, v in zip(special_tokens, special_token_ids):
v = int(v)
if v >= 0:
result[k] = v
else:
logging.warning("Could not find %s as token \"%s\" in vocab file %s",
k, special_tokens[k], vocab_file)
return result
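# Usage sketch (illustrative; the vocab path is hypothetical, see
# BertTokenizerTest further below for a runnable version with a temp vocab):
#
#   tokenizer = BertTokenizer(vocab_file="/path/to/vocab.txt", lower_case=True)
#   token_ids = tokenizer(tf.constant(["abc def"]))   # RaggedTensor of shape
#                                                     # [batch, (words), (pieces)]
#   special = tokenizer.get_special_tokens_dict()     # dict of special token ids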
class SentencepieceTokenizer(tf.keras.layers.Layer):
"""Wraps `tf_text.SentencepieceTokenizer` as a Keras Layer.
Attributes:
tokenize_with_offsets: If true, calls
`SentencepieceTokenizer.tokenize_with_offsets()`
instead of plain `.tokenize()` and outputs a triple of
`(tokens, start_offsets, limit_offsets)`.
"""
def __init__(self,
*,
lower_case: bool,
model_file_path: Optional[str] = None,
model_serialized_proto: Optional[str] = None,
tokenize_with_offsets: bool = False,
nbest_size: int = 0,
alpha: float = 1.0,
strip_diacritics: bool = False,
**kwargs):
"""Initializes a SentencepieceTokenizer layer.
Args:
lower_case: A Python boolean indicating whether to lowercase the string
before tokenization. NOTE: New models are encouraged to build `*_cf`
(case folding) normalization into the Sentencepiece model itself and
avoid this extra step.
model_file_path: A Python string with the path of the sentencepiece model.
Exactly one of `model_file_path` and `model_serialized_proto` can be
specified. In either case, the Keras model config for this layer will
store the actual proto (not a filename passed here).
model_serialized_proto: The sentencepiece model serialized proto string.
tokenize_with_offsets: A Python boolean. If true, this layer calls
`SentencepieceTokenizer.tokenize_with_offsets()` instead of
plain `.tokenize()` and outputs a triple of
`(tokens, start_offsets, limit_offsets)` instead of just tokens.
Note that when `strip_diacritics` is set to True, returning offsets
is not currently supported.
nbest_size: A scalar for sampling:
nbest_size = {0,1}: No sampling is performed. (default)
nbest_size > 1: samples from the nbest_size results.
nbest_size < 0: assuming that nbest_size is infinite and samples
from all hypotheses (lattice) using the
forward-filtering-and-backward-sampling algorithm.
alpha: A scalar for a smoothing parameter. Inverse temperature for
probability rescaling.
strip_diacritics: Whether to strip diacritics or not. Note that stripping
diacritics requires additional text normalization and dropping bytes,
which makes it impossible to keep track of the offsets now. Hence
when `strip_diacritics` is set to True, we don't yet support
`tokenize_with_offsets`. NOTE: New models are encouraged to put this
into custom normalization rules for the Sentencepiece model itself to
avoid this extra step and the limitation regarding offsets.
**kwargs: standard arguments to `Layer()`.
Raises:
ImportError: if importing tensorflow_text failed.
"""
_check_if_tf_text_installed()
super().__init__(**kwargs)
if bool(model_file_path) == bool(model_serialized_proto):
raise ValueError("Exact one of `model_file_path` and "
"`model_serialized_proto` can be specified.")
# TODO(b/181866850): Support tokenize_with_offsets for strip_diacritics=True
if tokenize_with_offsets and strip_diacritics:
raise ValueError("`tokenize_with_offsets` is not supported when "
"`strip_diacritics` is set to True.")
if model_file_path:
self._model_serialized_proto = tf.io.gfile.GFile(model_file_path,
"rb").read()
else:
self._model_serialized_proto = model_serialized_proto
self._lower_case = lower_case
self.tokenize_with_offsets = tokenize_with_offsets
self._nbest_size = nbest_size
self._alpha = alpha
self._strip_diacritics = strip_diacritics
self._tokenizer = self._create_tokenizer()
self._special_tokens_dict = self._create_special_tokens_dict()
def _create_tokenizer(self):
return text.SentencepieceTokenizer(
model=self._model_serialized_proto,
out_type=tf.int32,
nbest_size=self._nbest_size,
alpha=self._alpha)
@property
def vocab_size(self):
return self._tokenizer.vocab_size()
def call(self, inputs: tf.Tensor):
"""Calls `text.SentencepieceTokenizer` on inputs.
Args:
inputs: A string Tensor of shape `(batch_size,)`.
Returns:
One or three of RaggedTensors if tokenize_with_offsets is False or True,
respectively. These are
tokens: A RaggedTensor of shape `[batch_size, (pieces)]` and type `int32`.
`tokens[i,j]` contains the j-th piece in the i-th input.
start_offsets, limit_offsets: If `tokenize_with_offsets` is True,
RaggedTensors of type `int64` with the same indices as tokens.
Element `[i,j]` contains the byte offset at the start, or past the
end, resp., for the j-th piece in the i-th input.
"""
if self._strip_diacritics:
if self.tokenize_with_offsets:
raise ValueError("`tokenize_with_offsets` is not supported yet when "
"`strip_diacritics` is set to True (b/181866850).")
inputs = text.normalize_utf8(inputs, "NFD")
inputs = tf.strings.regex_replace(inputs, r"\p{Mn}", "")
if self._lower_case:
inputs = text.case_fold_utf8(inputs)
# Prepare to reshape the result to work around broken shape inference.
batch_size = tf.shape(inputs)[0]
def _reshape(rt):
values = rt.values
row_splits = rt.row_splits
row_splits = tf.reshape(row_splits, [batch_size + 1])
return tf.RaggedTensor.from_row_splits(values, row_splits)
# Call the tokenizer.
if self.tokenize_with_offsets:
tokens, start_offsets, limit_offsets = (
self._tokenizer.tokenize_with_offsets(inputs))
return _reshape(tokens), _reshape(start_offsets), _reshape(limit_offsets)
else:
tokens = self._tokenizer.tokenize(inputs)
return _reshape(tokens)
def get_config(self):
# Skip in tf.saved_model.save(); fail if called directly.
raise NotImplementedError("TODO(b/170480226): implement")
def get_special_tokens_dict(self):
"""Returns dict of token ids, keyed by standard names for their purpose.
Returns:
A dict from Python strings to Python integers. Each key is a standard
name for a special token describing its use. (For example, "padding_id"
is what Sentencepiece calls "<pad>" but others may call "[PAD]".)
The corresponding value is the integer token id. If a special token
is not found, its entry is omitted from the dict.
The supported keys and tokens are:
* start_of_sequence_id: looked up from "[CLS]"
* end_of_segment_id: looked up from "[SEP]"
* padding_id: looked up from "<pad>"
* mask_id: looked up from "[MASK]"
* vocab_size: one past the largest token id used
"""
return self._special_tokens_dict
def _create_special_tokens_dict(self):
special_tokens = dict(
start_of_sequence_id=b"[CLS]",
end_of_segment_id=b"[SEP]",
padding_id=b"<pad>",
mask_id=b"[MASK]")
with tf.init_scope():
if tf.executing_eagerly():
special_token_ids = self._tokenizer.string_to_id(
tf.constant(list(special_tokens.values()), tf.string))
inverse_tokens = self._tokenizer.id_to_string(special_token_ids)
vocab_size = self._tokenizer.vocab_size()
else:
# A blast from the past: non-eager init context while building Model.
# This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
logging.warning(
"Non-eager init context; computing SentencepieceTokenizer's "
"special_tokens_dict in tf.compat.v1.Session")
with tf.Graph().as_default():
local_tokenizer = self._create_tokenizer()
special_token_ids_tensor = local_tokenizer.string_to_id(
tf.constant(list(special_tokens.values()), tf.string))
inverse_tokens_tensor = local_tokenizer.id_to_string(
special_token_ids_tensor)
vocab_size_tensor = local_tokenizer.vocab_size()
with tf.compat.v1.Session() as sess:
special_token_ids, inverse_tokens, vocab_size = sess.run(
[special_token_ids_tensor, inverse_tokens_tensor,
vocab_size_tensor])
result = dict(
vocab_size=int(vocab_size) # Numpy to Python.
)
for name, token_id, inverse_token in zip(special_tokens,
special_token_ids,
inverse_tokens):
if special_tokens[name] == inverse_token:
result[name] = int(token_id)
else:
logging.warning(
"Could not find %s as token \"%s\" in sentencepiece model, "
"got \"%s\"", name, special_tokens[name], inverse_token)
return result
class BertPackInputs(tf.keras.layers.Layer):
"""Packs tokens into model inputs for BERT."""
def __init__(self,
seq_length,
*,
start_of_sequence_id=None,
end_of_segment_id=None,
padding_id=None,
special_tokens_dict=None,
truncator="round_robin",
**kwargs):
"""Initializes with a target `seq_length`, relevant token ids and truncator.
Args:
seq_length: The desired output length. Must not exceed the max_seq_length
that was fixed at training time for the BERT model receiving the inputs.
start_of_sequence_id: The numeric id of the token that is to be placed
at the start of each sequence (called "[CLS]" for BERT).
end_of_segment_id: The numeric id of the token that is to be placed
at the end of each input segment (called "[SEP]" for BERT).
padding_id: The numeric id of the token that is to be placed into the
unused positions after the last segment in the sequence
(called "[PAD]" for BERT).
special_tokens_dict: Optionally, a dict from Python strings to Python
integers that contains values for `start_of_sequence_id`,
`end_of_segment_id` and `padding_id`. (Further values in the dict are
silently ignored.) If this is passed, separate *_id arguments must be
omitted.
truncator: The algorithm to truncate a list of batched segments to fit a
per-example length limit. The value can be either `round_robin` or
`waterfall`:
(1) For "round_robin" algorithm, available space is assigned
one token at a time in a round-robin fashion to the inputs that still
need some, until the limit is reached. It supports an arbitrary
number of segments.
(2) For "waterfall" algorithm, the allocation of the budget is done
using a "waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run out of
budget. It supports an arbitrary number of segments.
**kwargs: standard arguments to `Layer()`.
Raises:
ImportError: if importing `tensorflow_text` failed.
"""
_check_if_tf_text_installed()
super().__init__(**kwargs)
self.seq_length = seq_length
if truncator not in ("round_robin", "waterfall"):
raise ValueError("Only 'round_robin' and 'waterfall' algorithms are "
"supported, but got %s" % truncator)
self.truncator = truncator
self._init_token_ids(
start_of_sequence_id=start_of_sequence_id,
end_of_segment_id=end_of_segment_id,
padding_id=padding_id,
special_tokens_dict=special_tokens_dict)
def _init_token_ids(
self, *,
start_of_sequence_id,
end_of_segment_id,
padding_id,
special_tokens_dict):
usage = ("Must pass either all of start_of_sequence_id, end_of_segment_id, "
"padding_id as arguments, or else a special_tokens_dict "
"with those keys.")
special_tokens_args = [start_of_sequence_id, end_of_segment_id, padding_id]
if special_tokens_dict is None:
if any(x is None for x in special_tokens_args):
raise ValueError(usage)
self.start_of_sequence_id = int(start_of_sequence_id)
self.end_of_segment_id = int(end_of_segment_id)
self.padding_id = int(padding_id)
else:
if any(x is not None for x in special_tokens_args):
raise ValueError(usage)
self.start_of_sequence_id = int(
special_tokens_dict["start_of_sequence_id"])
self.end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
self.padding_id = int(special_tokens_dict["padding_id"])
def get_config(self) -> Dict[str, Any]:
config = super().get_config()
config["seq_length"] = self.seq_length
config["start_of_sequence_id"] = self.start_of_sequence_id
config["end_of_segment_id"] = self.end_of_segment_id
config["padding_id"] = self.padding_id
config["truncator"] = self.truncator
return config
def call(self, inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]]):
"""Adds special tokens to pack a list of segments into BERT input Tensors.
Args:
inputs: A Python list of one or two RaggedTensors, each with the batched
values of one input segment. The j-th segment of the i-th input example
consists of slice `inputs[j][i, ...]`.
Returns:
A nest of Tensors for use as input to the BERT TransformerEncoder.
"""
# BertPackInputsSavedModelWrapper relies on only calling bert_pack_inputs()
return BertPackInputs.bert_pack_inputs(
inputs, self.seq_length,
start_of_sequence_id=self.start_of_sequence_id,
end_of_segment_id=self.end_of_segment_id,
padding_id=self.padding_id,
truncator=self.truncator)
@staticmethod
def bert_pack_inputs(inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
seq_length: Union[int, tf.Tensor],
start_of_sequence_id: Union[int, tf.Tensor],
end_of_segment_id: Union[int, tf.Tensor],
padding_id: Union[int, tf.Tensor],
truncator="round_robin"):
"""Freestanding equivalent of the BertPackInputs layer."""
_check_if_tf_text_installed()
# Sanitize inputs.
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
if not inputs:
raise ValueError("At least one input is required for packing")
input_ranks = [rt.shape.rank for rt in inputs]
if None in input_ranks or len(set(input_ranks)) > 1:
raise ValueError("All inputs for packing must have the same known rank, "
"found ranks " + ",".join(input_ranks))
# Flatten inputs to [batch_size, (tokens)].
if input_ranks[0] > 2:
inputs = [rt.merge_dims(1, -1) for rt in inputs]
# In case inputs weren't truncated (as they should have been),
# fall back to some ad-hoc truncation.
num_special_tokens = len(inputs) + 1
if truncator == "round_robin":
trimmed_segments = round_robin_truncate_inputs(
inputs, seq_length - num_special_tokens)
elif truncator == "waterfall":
trimmed_segments = text.WaterfallTrimmer(
seq_length - num_special_tokens).trim(inputs)
else:
raise ValueError("Unsupported truncator: %s" % truncator)
# Combine segments.
segments_combined, segment_ids = text.combine_segments(
trimmed_segments,
start_of_sequence_id=start_of_sequence_id,
end_of_segment_id=end_of_segment_id)
# Pad to dense Tensors.
input_word_ids, _ = text.pad_model_inputs(segments_combined, seq_length,
pad_value=padding_id)
input_type_ids, input_mask = text.pad_model_inputs(segment_ids, seq_length,
pad_value=0)
# Work around broken shape inference.
output_shape = tf.stack([
inputs[0].nrows(out_type=tf.int32), # batch_size
tf.cast(seq_length, dtype=tf.int32)])
def _reshape(t):
return tf.reshape(t, output_shape)
# Assemble nest of input tensors as expected by BERT TransformerEncoder.
return dict(input_word_ids=_reshape(input_word_ids),
input_mask=_reshape(input_mask),
input_type_ids=_reshape(input_type_ids))
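# Usage sketch (hypothetical token ids, assuming [CLS]=101, [SEP]=102, [PAD]=0):
#
#   seg_a = tf.ragged.constant([[201, 202, 203]])
#   seg_b = tf.ragged.constant([[211, 212]])
#   packer = BertPackInputs(seq_length=10, start_of_sequence_id=101,
#                           end_of_segment_id=102, padding_id=0)
#   packed = packer([seg_a, seg_b])
#   # Per the packing logic above:
#   # packed["input_word_ids"][0] == [101, 201, 202, 203, 102, 211, 212, 102, 0, 0]
#   # packed["input_mask"][0]     == [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
#   # packed["input_type_ids"][0] == [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]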
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests bert.text_layers."""
import os
import tempfile
import numpy as np
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from official.nlp.modeling.layers import text_layers
class RoundRobinTruncatorTest(tf.test.TestCase):
def _test_input(self, start, lengths):
return tf.ragged.constant([[start + 10 * j + i
for i in range(length)]
for j, length in enumerate(lengths)],
dtype=tf.int32)
def test_single_segment(self):
# Single segment.
single_input = self._test_input(11, [4, 5, 6])
expected_single_output = tf.ragged.constant(
[[11, 12, 13, 14],
[21, 22, 23, 24, 25],
[31, 32, 33, 34, 35], # Truncated.
])
self.assertAllEqual(
expected_single_output,
text_layers.round_robin_truncate_inputs(single_input, limit=5))
# Test wrapping in a singleton list.
actual_single_list_output = text_layers.round_robin_truncate_inputs(
[single_input], limit=5)
self.assertIsInstance(actual_single_list_output, list)
self.assertAllEqual(expected_single_output, actual_single_list_output[0])
def test_two_segments(self):
input_a = self._test_input(111, [1, 2, 2, 3, 4, 5])
input_b = self._test_input(211, [1, 3, 4, 2, 2, 5])
expected_a = tf.ragged.constant(
[[111],
[121, 122],
[131, 132],
[141, 142, 143],
[151, 152, 153], # Truncated.
[161, 162, 163], # Truncated.
])
expected_b = tf.ragged.constant(
[[211],
[221, 222, 223],
[231, 232, 233], # Truncated.
[241, 242],
[251, 252],
[261, 262], # Truncated.
])
actual_a, actual_b = text_layers.round_robin_truncate_inputs(
[input_a, input_b], limit=5)
self.assertAllEqual(expected_a, actual_a)
self.assertAllEqual(expected_b, actual_b)
def test_three_segments(self):
input_a = self._test_input(111, [1, 2, 2, 3, 4, 5, 1])
input_b = self._test_input(211, [1, 3, 4, 2, 2, 5, 8])
input_c = self._test_input(311, [1, 3, 4, 2, 2, 5, 10])
seg_limit = 8
expected_a = tf.ragged.constant([
[111],
[121, 122],
[131, 132],
[141, 142, 143],
[151, 152, 153, 154],
[161, 162, 163], # Truncated
[171]
])
expected_b = tf.ragged.constant([
[211],
[221, 222, 223],
[231, 232, 233], # Truncated
[241, 242],
[251, 252],
[261, 262, 263], # Truncated
[271, 272, 273, 274] # Truncated
])
expected_c = tf.ragged.constant([
[311],
[321, 322, 323],
[331, 332, 333], # Truncated
[341, 342],
[351, 352],
[361, 362], # Truncated
[371, 372, 373] # Truncated
])
actual_a, actual_b, actual_c = text_layers.round_robin_truncate_inputs(
[input_a, input_b, input_c], limit=seg_limit)
self.assertAllEqual(expected_a, actual_a)
self.assertAllEqual(expected_b, actual_b)
self.assertAllEqual(expected_c, actual_c)
input_cap = tf.math.reduce_sum(
tf.stack([rt.row_lengths() for rt in [input_a, input_b, input_c]]),
axis=0)
per_example_usage = tf.math.reduce_sum(
tf.stack([rt.row_lengths() for rt in [actual_a, actual_b, actual_c]]),
axis=0)
self.assertTrue(all(per_example_usage <= tf.minimum(seg_limit, input_cap)))
# This test covers the in-process behavior of a BertTokenizer layer.
# For saving, restoring, and the restored behavior (incl. shape inference),
# see nlp/tools/export_tfhub_lib_test.py.
class BertTokenizerTest(tf.test.TestCase):
def _make_vocab_file(self, vocab, filename="vocab.txt"):
path = os.path.join(
tempfile.mkdtemp(dir=self.get_temp_dir()), # New subdir each time.
filename)
with tf.io.gfile.GFile(path, "w") as f:
f.write("\n".join(vocab + [""]))
return path
def test_uncased(self):
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
bert_tokenize = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=True)
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = bert_tokenize(inputs)
self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
[[6], [4, 5], [4]]]))
bert_tokenize.tokenize_with_offsets = True
token_ids_2, start_offsets, limit_offsets = bert_tokenize(inputs)
self.assertAllEqual(token_ids, token_ids_2)
self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
[[0], [4, 5], [8]]]))
self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
[[3], [5, 7], [9]]]))
self.assertEqual(bert_tokenize.vocab_size.numpy(), 8)
# Repeat the above and test that case matters with lower_case=False.
def test_cased(self):
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "ABC"])
bert_tokenize = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=False, tokenize_with_offsets=True)
inputs = tf.constant(["abc def", "ABC DEF"])
token_ids, start_offsets, limit_offsets = bert_tokenize(inputs)
self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
[[7], [1]]]))
self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
[[0], [4]]]))
self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
[[3], [7]]]))
def test_special_tokens_complete(self):
vocab_file = self._make_vocab_file(
["foo", "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "xy"])
bert_tokenize = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=True)
self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
dict(padding_id=1,
start_of_sequence_id=3,
end_of_segment_id=4,
mask_id=5,
vocab_size=7))
def test_special_tokens_partial(self):
vocab_file = self._make_vocab_file(
["[PAD]", "[CLS]", "[SEP]"])
bert_tokenize = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=True)
self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
dict(padding_id=0,
start_of_sequence_id=1,
end_of_segment_id=2,
                              vocab_size=3))  # No mask_id.
def test_special_tokens_in_estimator(self):
"""Tests getting special tokens without an Eager init context."""
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
def input_fn():
with tf.init_scope():
self.assertFalse(tf.executing_eagerly())
# Build a preprocessing Model.
sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
bert_tokenizer = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=True)
special_tokens_dict = bert_tokenizer.get_special_tokens_dict()
for k, v in special_tokens_dict.items():
self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
tokens = bert_tokenizer(sentences)
packed_inputs = text_layers.BertPackInputs(
4, special_tokens_dict=special_tokens_dict)(tokens)
preprocessing = tf.keras.Model(sentences, packed_inputs)
# Map the dataset.
ds = tf.data.Dataset.from_tensors(
(tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
ds = ds.map(lambda features, labels: (preprocessing(features), labels))
return ds
def model_fn(features, labels, mode):
del labels # Unused.
return tf.estimator.EstimatorSpec(mode=mode,
predictions=features["input_word_ids"])
estimator = tf.estimator.Estimator(model_fn=model_fn)
outputs = list(estimator.predict(input_fn))
self.assertAllEqual(outputs, np.array([[2, 6, 3, 0],
[2, 4, 5, 3]]))
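    # For reference, the same preprocessing can be exercised eagerly without an
    # Estimator (a minimal sketch, assuming the `vocab_file` created above; the
    # names below are illustrative and not part of this test):
    #
    #   tokenizer = text_layers.BertTokenizer(
    #       vocab_file=vocab_file, lower_case=True)
    #   packer = text_layers.BertPackInputs(
    #       4, special_tokens_dict=tokenizer.get_special_tokens_dict())
    #   packed = packer(tokenizer(tf.constant(["abc", "DEF"])))
    #   # `packed` is a dict with "input_word_ids", "input_mask" and
    #   # "input_type_ids", matching the Estimator outputs asserted above.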
# This test covers the in-process behavior of a SentencepieceTokenizer layer.
class SentencepieceTokenizerTest(tf.test.TestCase):
def setUp(self):
super().setUp()
# Make a sentencepiece model.
tmp_dir = self.get_temp_dir()
tempfile.mkdtemp(dir=tmp_dir)
vocab = ["a", "b", "c", "d", "e", "abc", "def", "ABC", "DEF"]
model_prefix = os.path.join(tmp_dir, "spm_model")
input_text_file_path = os.path.join(tmp_dir, "train_input.txt")
with tf.io.gfile.GFile(input_text_file_path, "w") as f:
f.write(" ".join(vocab + ["\n"]))
# Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
full_vocab_size = len(vocab) + 7
flags = dict(
model_prefix=model_prefix,
model_type="word",
input=input_text_file_path,
pad_id=0, unk_id=1, control_symbols="[CLS],[SEP],[MASK]",
vocab_size=full_vocab_size,
bos_id=full_vocab_size-2, eos_id=full_vocab_size-1)
SentencePieceTrainer.Train(
" ".join(["--{}={}".format(k, v) for k, v in flags.items()]))
self._spm_path = model_prefix + ".model"
def test_uncased(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path, lower_case=True, nbest_size=0)
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = sentencepiece_tokenizer(inputs)
self.assertAllEqual(
token_ids,
tf.ragged.constant([[8, 12], [8, 12, 11]]))
sentencepiece_tokenizer.tokenize_with_offsets = True
token_ids_2, start_offsets, limit_offsets = sentencepiece_tokenizer(inputs)
self.assertAllEqual(token_ids, token_ids_2)
self.assertAllEqual(
start_offsets, tf.ragged.constant([[0, 3], [0, 3, 7]]))
self.assertAllEqual(
limit_offsets, tf.ragged.constant([[3, 7], [3, 7, 9]]))
self.assertEqual(sentencepiece_tokenizer.vocab_size.numpy(), 16)
# Repeat the above and test that case matters with lower_case=False.
def test_cased(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
lower_case=False,
nbest_size=0,
tokenize_with_offsets=False)
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = sentencepiece_tokenizer(inputs)
self.assertAllEqual(
token_ids,
tf.ragged.constant([[8, 12], [5, 6, 11]]))
sentencepiece_tokenizer.tokenize_with_offsets = True
token_ids_2, start_offsets, limit_offsets = sentencepiece_tokenizer(inputs)
self.assertAllEqual(token_ids, token_ids_2)
self.assertAllEqual(
start_offsets,
tf.ragged.constant([[0, 3], [0, 3, 7]]))
self.assertAllEqual(
limit_offsets,
tf.ragged.constant([[3, 7], [3, 7, 9]]))
def test_special_tokens(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path, lower_case=True, nbest_size=0)
self.assertDictEqual(sentencepiece_tokenizer.get_special_tokens_dict(),
dict(padding_id=0,
start_of_sequence_id=2,
end_of_segment_id=3,
mask_id=4,
vocab_size=16))
def test_special_tokens_in_estimator(self):
"""Tests getting special tokens without an Eager init context."""
def input_fn():
with tf.init_scope():
self.assertFalse(tf.executing_eagerly())
# Build a preprocessing Model.
sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path, lower_case=True, nbest_size=0)
special_tokens_dict = sentencepiece_tokenizer.get_special_tokens_dict()
for k, v in special_tokens_dict.items():
self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
tokens = sentencepiece_tokenizer(sentences)
packed_inputs = text_layers.BertPackInputs(
4, special_tokens_dict=special_tokens_dict)(tokens)
preprocessing = tf.keras.Model(sentences, packed_inputs)
# Map the dataset.
ds = tf.data.Dataset.from_tensors(
(tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
ds = ds.map(lambda features, labels: (preprocessing(features), labels))
return ds
def model_fn(features, labels, mode):
del labels # Unused.
return tf.estimator.EstimatorSpec(mode=mode,
predictions=features["input_word_ids"])
estimator = tf.estimator.Estimator(model_fn=model_fn)
outputs = list(estimator.predict(input_fn))
self.assertAllEqual(outputs, np.array([[2, 8, 3, 0],
[2, 12, 3, 0]]))
def test_strip_diacritics(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
lower_case=True,
nbest_size=0,
strip_diacritics=True)
inputs = tf.constant(["a b c d e", "ă ḅ č ḓ é"])
token_ids = sentencepiece_tokenizer(inputs)
self.assertAllEqual(
token_ids,
tf.ragged.constant([[7, 9, 10, 11, 13], [7, 9, 10, 11, 13]]))
def test_fail_on_tokenize_with_offsets_and_strip_diacritics(self):
# Raise an error in init().
with self.assertRaises(ValueError):
text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
tokenize_with_offsets=True,
lower_case=True,
nbest_size=0,
strip_diacritics=True)
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
lower_case=True,
nbest_size=0,
strip_diacritics=True)
sentencepiece_tokenizer.tokenize_with_offsets = True
# Raise an error in call():
inputs = tf.constant(["abc def", "ABC DEF d", "Äffin"])
with self.assertRaises(ValueError):
sentencepiece_tokenizer(inputs)
def test_serialize_deserialize(self):
self.skipTest("b/170480226")
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
lower_case=False,
nbest_size=0,
tokenize_with_offsets=False,
name="sentencepiece_tokenizer_layer")
config = sentencepiece_tokenizer.get_config()
new_tokenizer = text_layers.SentencepieceTokenizer.from_config(config)
self.assertEqual(config, new_tokenizer.get_config())
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = sentencepiece_tokenizer(inputs)
token_ids_2 = new_tokenizer(inputs)
self.assertAllEqual(token_ids, token_ids_2)
# TODO(b/170480226): Remove once tf_hub_export_lib_test.py covers saving.
def test_saving(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path, lower_case=True, nbest_size=0)
inputs = tf.keras.layers.Input([], dtype=tf.string)
outputs = sentencepiece_tokenizer(inputs)
model = tf.keras.Model(inputs, outputs)
export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
model.save(export_path, signatures={})
class BertPackInputsTest(tf.test.TestCase):
def test_round_robin_correct_outputs(self):
bpi = text_layers.BertPackInputs(
10,
start_of_sequence_id=1001,
end_of_segment_id=1002,
padding_id=999,
truncator="round_robin")
# Single input, rank 2.
bert_inputs = bpi(
tf.ragged.constant([[11, 12, 13],
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
[1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
# Two inputs, rank 3. Truncation does not respect word boundaries.
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
tf.ragged.constant([[[211, 212], [213]],
[[221, 222], [223, 224, 225], [226, 227, 228]]])
])
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
[1001, 121, 122, 123, 124, 1002, 221, 222, 223, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]))
# Three inputs. rank 3.
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
tf.ragged.constant([[[211, 212], [213]],
[[221, 222], [223, 224, 225], [226, 227, 228]]]),
tf.ragged.constant([[[311, 312], [313]],
[[321, 322], [323, 324, 325], [326, 327, 328]]])
])
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 111, 112, 1002, 211, 212, 1002, 311, 312, 1002],
[1001, 121, 122, 1002, 221, 222, 1002, 321, 322, 1002]]))
def test_waterfall_correct_outputs(self):
bpi = text_layers.BertPackInputs(
10,
start_of_sequence_id=1001,
end_of_segment_id=1002,
padding_id=999,
truncator="waterfall")
# Single input, rank 2.
bert_inputs = bpi(
tf.ragged.constant([[11, 12, 13],
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
[1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
# Two inputs, rank 3. Truncation does not respect word boundaries.
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
tf.ragged.constant([[[211, 212], [213]],
[[221, 222], [223, 224, 225], [226, 227, 228]]])
])
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
[1001, 121, 122, 123, 124, 125, 126, 127, 1002, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]))
# Three inputs, rank 3. Truncation does not respect word boundaries.
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
tf.ragged.constant([[[211], [212]],
[[221, 222], [223, 224, 225], [226, 227, 228]]]),
tf.ragged.constant([[[311, 312], [313]],
[[321, 322], [323, 324, 325], [326, 327]]])
])
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 1002, 311, 1002],
[1001, 121, 122, 123, 124, 125, 126, 1002, 1002, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 2, 2],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 2]]))
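    # Unlike "round_robin", the "waterfall" truncator spends the shared token
    # budget left to right: earlier segments keep as many tokens as fit and
    # later segments only receive whatever remains (possibly nothing), as the
    # second rows of the expected tensors above show.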
def test_special_tokens_dict(self):
special_tokens_dict = dict(start_of_sequence_id=1001,
end_of_segment_id=1002,
padding_id=999,
extraneous_key=666)
bpi = text_layers.BertPackInputs(10,
special_tokens_dict=special_tokens_dict)
bert_inputs = bpi(
tf.ragged.constant([[11, 12, 13],
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
[1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ExpandCondense tensor network layer used in TN-BERT."""
# pylint: disable=g-classes-have-attributes
from typing import List, Optional, Text, Any, Dict
import tensorflow as tf
Layer = tf.keras.layers.Layer
activations = tf.keras.activations
initializers = tf.keras.initializers
@tf.keras.utils.register_keras_serializable(package='Text')
class TNExpandCondense(Layer):
"""A TPU-optimized TensorNetwork layer.
  Designed for use in models that currently use Dense layers for an
  up-projection followed by a down-projection.
This layer is a TPU-optimized combination of 3 operations:
Expand, Apply Activation, and Condense. The layer projects up from
`input_shape[-1]` to `input_shape[-1] * proj_multiplier`, applies
`self.activation`, and then condenses back to `input_shape[-1]`.
Note the input shape and output shape will be identical.
Args:
    proj_multiplier: Positive integer, multiple of `input_shape[-1]` to project
      up to. Must be one of `[2, 4, 6, 8, 10, 12]`.
use_bias: Boolean, whether the layer uses a bias vector.
    activation: Activation function to use between Expand and Condense. If you
      don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
kernel_initializer: Initializer for the weight matrices.
bias_initializer: Initializer for the bias vector.
Input shape:
N-D tensor with shape: `(batch_size, ..., input_shape[-1])`.
Output shape:
N-D tensor with shape: `(batch_size, ..., input_shape[-1])`.
"""
def __init__(self,
proj_multiplier: int,
use_bias: Optional[bool] = True,
activation: Optional[Text] = 'relu',
kernel_initializer: Optional[Text] = 'glorot_uniform',
bias_initializer: Optional[Text] = 'zeros',
**kwargs) -> None:
    # Allow specification of input_dim instead of input_shape,
    # for compatibility with Keras layers that support this.
if 'input_shape' not in kwargs and 'input_dim' in kwargs:
kwargs['input_shape'] = (kwargs.pop('input_dim'),)
super(TNExpandCondense, self).__init__(**kwargs)
assert proj_multiplier in [
2, 4, 6, 8, 10, 12
], 'proj_multiplier needs to be one of [2, 4, 6, 8, 10, 12]'
self.proj_multiplier = proj_multiplier
self.use_bias = use_bias
self.activation = activations.get(activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
def build(self, input_shape: List[int]) -> None:
# Disable the attribute-defined-outside-init violations in this function
# pylint: disable=attribute-defined-outside-init
if input_shape[-1] is None:
raise ValueError(
'The last dimension of the inputs to `TNExpandCondense` '
'should be defined. Found `None`.')
super(TNExpandCondense, self).build(input_shape)
self.proj_size = self.proj_multiplier * input_shape[-1]
    assert (self.proj_size // input_shape[-1]) * input_shape[
        -1] == self.proj_size, (f'{self.proj_size} must be evenly divisible '
                                f'by {input_shape[-1]}')
    assert (input_shape[-1] // 128
           ) * 128 == input_shape[-1], (
               f'{input_shape[-1]} must be evenly divisible by 128')
self.w1 = self.add_weight(
name='w1',
shape=(input_shape[-1], input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
self.w2 = self.add_weight(
name='w2',
shape=(128, (128 * (self.proj_size // input_shape[-1]))),
trainable=True,
initializer=self.kernel_initializer)
self.w3 = self.add_weight(
name='w3',
shape=(128 * (self.proj_size // input_shape[-1]), 128),
trainable=True,
initializer=self.kernel_initializer)
self.w4 = self.add_weight(
name='w4',
shape=(input_shape[-1] // 128, 128, input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
if self.use_bias:
self.bias = self.add_weight(
name='b',
shape=(input_shape[-1] // 128, 1,
128 * (self.proj_size // input_shape[-1])),
trainable=True,
initializer=self.bias_initializer)
else:
self.bias = None
  def call(self, inputs: tf.Tensor, **kwargs):
    orig_shape = tf.shape(inputs)
    input_dim = inputs.shape[-1]
    tmp = tf.reshape(inputs, (-1, input_dim))
    # Shape is (BatchSeq, input_dim).
    # Expansion network.
    # Note: Letter Q will always represent the BatchSeq axis.
    tmp = tf.einsum('ab,Qb->aQ', self.w1, tmp)
    # Shape is (input_dim, BatchSeq).
    tmp = tf.reshape(tmp, (input_dim // 128, 128, -1))
    tmp = tf.einsum('abQ,bd->aQd', tmp, self.w2)
    # Shape is (input_dim // 128, BatchSeq, 128 * proj_multiplier).
    # Apply bias (if any) and activation, then Condense.
    if self.bias is not None:
      tmp += self.bias
    tmp = self.activation(tmp)
    tmp = tf.einsum('aQd,db->aQb', tmp, self.w3)
    # Shape is (input_dim // 128, BatchSeq, 128).
    tmp = tf.einsum('aQb,abd->Qd', tmp, self.w4)
    # Shape is (BatchSeq, input_dim); restore the original leading dims.
    out = tf.reshape(tmp, orig_shape)
    return out
def compute_output_shape(self, input_shape: List[int]) -> List[int]:
return input_shape
def get_config(self) -> Dict[Any, Any]:
"""Returns the config of the layer.
The same layer can be reinstantiated later
(without its trained weights) from this configuration.
Returns:
Python dictionary containing the configuration of the layer.
"""
config = {}
# Include the layer-specific arguments
args = ['proj_multiplier', 'use_bias']
for arg in args:
config[arg] = getattr(self, arg)
# Serialize the activation
config['activation'] = activations.serialize(getattr(self, 'activation'))
# Serialize the initializers
decomp_initializers = ['kernel_initializer', 'bias_initializer']
for initializer_arg in decomp_initializers:
config[initializer_arg] = initializers.serialize(
getattr(self, initializer_arg))
# Get base config
base_config = super(TNExpandCondense, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ExpandCondense tensor network layer."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.testing_utils import layer_test
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for the ExpandCondense TN layer."""
def setUp(self):
super(TNLayerTest, self).setUp()
self.labels = np.concatenate((np.ones((50, 1)), np.zeros((50, 1))), axis=0)
def _build_model(self, data, proj_multiple=2):
model = tf.keras.models.Sequential()
model.add(
TNExpandCondense(
proj_multiplier=proj_multiple,
use_bias=True,
activation='relu',
input_shape=(data.shape[-1],)))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
return model
@parameterized.parameters((768, 6), (1024, 2))
def test_keras_layer(self, input_dim, proj_multiple):
    self.skipTest('Disable the test for now since it imports '
                  'keras.testing_utils; will re-enable this test after '
                  'fixing b/184578869.')
    # TODO(scottzhu): Re-enable after fixing b/184578869.
data = np.random.normal(size=(100, input_dim))
data = data.astype(np.float32)
layer_test(
TNExpandCondense,
kwargs={
'proj_multiplier': proj_multiple,
'input_shape': data.shape
},
input_shape=data.shape,
input_data=data,
expected_output_shape=(None, data.shape[-1]),
expected_output_dtype=data.dtype)
@parameterized.parameters((768, 6), (1024, 2))
def test_train(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
tf.random.set_seed(0)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train the model for 5 epochs
history = model.fit(data, self.labels, epochs=5, batch_size=32)
# Check that loss decreases and accuracy increases
self.assertGreater(history.history['loss'][0], history.history['loss'][-1])
self.assertLess(
history.history['accuracy'][0], history.history['accuracy'][-1])
@parameterized.parameters((768, 6), (1024, 2))
def test_weights_change(self, input_dim, proj_multiple):
tf.random.set_seed(0)
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
before = model.get_weights()
model.fit(data, self.labels, epochs=5, batch_size=32)
after = model.get_weights()
# Make sure every layer's weights changed
for i, _ in enumerate(before):
self.assertTrue((after[i] != before[i]).any())
@parameterized.parameters((768, 6), (1024, 2))
def test_output_shape(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
input_shape = data.shape
actual_output_shape = model(data).shape
expected_output_shape = model.compute_output_shape(input_shape)
self.assertEqual(expected_output_shape, actual_output_shape)
@parameterized.parameters((768, 6), (1024, 2))
def test_expandcondense_num_parameters(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
proj_size = proj_multiple * data.shape[-1]
model = tf.keras.models.Sequential()
model.add(
TNExpandCondense(
proj_multiplier=proj_multiple,
use_bias=True,
activation='relu',
input_shape=(data.shape[-1],)))
w1_params = data.shape[-1]**2
w2_params = 128 * 128 * (proj_size // data.shape[-1])
w3_params = 128 * 128 * (proj_size // data.shape[-1])
w4_params = (data.shape[-1] // 128) * 128 * data.shape[-1]
bias_params = ((data.shape[-1] // 128) * 128 *
(proj_size // data.shape[-1]))
expected_num_parameters = (w1_params + w2_params + w3_params +
w4_params) + bias_params
self.assertEqual(expected_num_parameters, model.count_params())
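    # Worked example (not asserted): for input_dim=768 and proj_multiple=6,
    # w1 = 768**2 = 589,824; w2 = w3 = 128 * 128 * 6 = 98,304 each;
    # w4 = 6 * 128 * 768 = 589,824; bias = 768 * 6 = 4,608, for a total of
    # 1,380,864 parameters. A pair of Dense layers (768 -> 4608 -> 768) would
    # need roughly 7.08M parameters, so the TN layer is about 5x smaller.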
@parameterized.parameters((912, 6), (200, 2))
def test_incorrect_sizes(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
with self.assertRaises(AssertionError):
model = self._build_model(data, proj_multiple)
model.compile(optimizer='adam', loss='binary_crossentropy')
@parameterized.parameters((768, 6), (1024, 2))
def test_config(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
expected_num_parameters = model.layers[0].count_params()
# Serialize model and use config to create new layer
model_config = model.get_config()
layer_config = model_config['layers'][1]['config']
new_model = TNExpandCondense.from_config(layer_config)
# Build the layer so we can count params below
new_model.build(layer_config['batch_input_shape'])
# Check that original layer had same num params as layer built from config
self.assertEqual(expected_num_parameters, new_model.count_params())
@parameterized.parameters((768, 6), (1024, 2))
def test_model_save(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train the model for 5 epochs
model.fit(data, self.labels, epochs=5, batch_size=32)
save_path = os.path.join(self.get_temp_dir(), 'test_model')
model.save(save_path)
loaded_model = tf.keras.models.load_model(save_path)
# Compare model predictions and loaded_model predictions
self.assertAllEqual(model.predict(data), loaded_model.predict(data))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TN-BERT TNTransformerExpandCondense employing Expand-Condense layer instead of Dense."""
# pylint: disable=g-classes-have-attributes
# Import libraries
import gin
import tensorflow as tf
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TNTransformerExpandCondense(tf.keras.layers.Layer):
"""Transformer layer using tensor network Expand-Condense layer.
This layer implements the Transformer from transformer.py, with a single
tensor network layer replacing the usual intermediate and output Dense
layers.
Args:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
    output_range: The sequence output range, [0, output_range), obtained by
      slicing the target sequence. `None` means the target sequence is not
      sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
use_bias: Whether to enable use_bias in attention layer. If set to False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
    attention_initializer: Initializer for kernels of attention layers. If set
      to `None`, attention layers use `kernel_initializer` for their kernels.
"""
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
intermediate_dropout=0.0,
attention_initializer=None,
**kwargs):
super(TNTransformerExpandCondense, self).__init__(**kwargs)
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape.as_list()) != 3:
raise ValueError(
"TNTransformerExpandCondense expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError(
"When passing a mask tensor to TNTransformerExpandCondense, the "
"mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._attention_layer = tf.keras.layers.MultiHeadAttention(
num_heads=self._num_heads,
key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
# Substitute Dense layers with a single Expand-Condense layer.
self._output_dense = TNExpandCondense(
4,
use_bias=True,
activation=self._intermediate_activation,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
super(TNTransformerExpandCondense, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"intermediate_dropout":
self._intermediate_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer)
}
base_config = super(TNTransformerExpandCondense, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs
else:
input_tensor, attention_mask = (inputs, None)
    if self._output_range:
      if self._norm_first:
        # `source_tensor` is needed below for the pre-norm residual connection.
        source_tensor = input_tensor[:, 0:self._output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
      target_tensor = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
layer_output = self._output_dense(attention_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TN-BERT transformer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(('tn', TNTransformerExpandCondense))
class TransformerLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerLayerTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy('float32')
def test_layer_creation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer should have the same shape as
    # its input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
    # The default output of a transformer layer should have the same shape as
    # its input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
_ = test_layer([data_tensor, mask_tensor])
def test_layer_invocation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# Create a model from the test layer.
model = tf.keras.Model(data_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
_ = model.predict(input_data)
def test_layer_invocation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_layer_output_range(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor = test_layer([input_data, mask_data])
    # The layer only attends to the first token and outputs the first token
    # embedding.
new_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
output_range=1)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (16 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_transform_with_initializer(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output = test_layer(data_tensor)
    # The default output of a transformer layer should have the same shape as
    # its input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
def test_dynamic_layer_sequence(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
width = 256
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_length = 17
input_data = np.ones((1, input_length, width))
output_data = model.predict(input_data)
self.assertAllEqual([1, input_length, width], output_data.shape)
if __name__ == '__main__':
tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,30 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based transformer block layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import gin
import tensorflow as tf
from official.nlp import keras_nlp
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers.util import tf_function_if_eager
@tf.keras.utils.register_keras_serializable(package="Text")
class Transformer(tf.keras.layers.Layer):
class Transformer(keras_nlp.layers.TransformerEncoderBlock):
"""Transformer layer.
This layer implements the Transformer from "Attention Is All You Need".
(https://arxiv.org/abs/1706.03762).
Arguments:
Args:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
......@@ -49,6 +46,15 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for kernel.
"""
def __init__(self,
......@@ -65,161 +71,32 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
intermediate_dropout=0.0,
attention_initializer=None,
**kwargs):
super(Transformer, self).__init__(**kwargs)
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError("When passing a mask tensor to TransformerLayer, the "
"mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
**common_kwargs)
# pylint: disable=protected-access
self._attention_layer.build([input_tensor_shape] * 3)
self._attention_output_dense = self._attention_layer._output_dense
# pylint: enable=protected-access
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
super(Transformer, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(Transformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs
else:
input_tensor, attention_mask = (inputs, None)
if self._output_range:
target_tensor = input_tensor[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output
super().__init__(
num_attention_heads=num_attention_heads,
inner_dim=intermediate_size,
inner_activation=intermediate_activation,
output_dropout=dropout_rate,
attention_dropout=attention_dropout_rate,
output_range=output_range,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
use_bias=use_bias,
norm_first=norm_first,
norm_epsilon=norm_epsilon,
inner_dropout=intermediate_dropout,
attention_initializer=attention_initializer,
**kwargs)
@tf.keras.utils.register_keras_serializable(package="Text")
......@@ -228,11 +105,11 @@ class CompiledTransformer(Transformer):
@tf_function_if_eager(experimental_compile=True)
def call(self, inputs):
return super(CompiledTransformer, self).call(inputs)
return super().call(inputs)
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerDecoderLayer(tf.keras.layers.Layer):
class TransformerDecoderBlock(tf.keras.layers.Layer):
"""Single transformer layer for decoder.
It has three sub-layers:
......@@ -240,7 +117,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
  (2) an encoder-decoder attention.
(3) a positionwise fully connected feed-forward network.
Arguments:
Args:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
......@@ -255,6 +132,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for kernel.
"""
def __init__(self,
......@@ -271,8 +157,13 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
intermediate_dropout=0.0,
attention_initializer=None,
**kwargs):
super(TransformerDecoderLayer, self).__init__(**kwargs)
super().__init__(**kwargs)
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf.keras.activations.get(
......@@ -287,6 +178,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
......@@ -294,7 +194,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
def build(self, input_shape):
target_tensor_shape = tf.TensorShape(input_shape[0])
if len(target_tensor_shape) != 3:
if len(target_tensor_shape.as_list()) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
hidden_size = target_tensor_shape[2]
......@@ -302,9 +202,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size / self.num_attention_heads)
self.attention_head_size = int(hidden_size) // self.num_attention_heads
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
......@@ -314,27 +213,35 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
# Self attention.
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
key_dim=self.attention_head_size,
dropout=self.attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
name="self_attention",
**common_kwargs)
self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="output",
**common_kwargs)
self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate)
self.self_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype="float32"))
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
key_dim=self.attention_head_size,
dropout=self.attention_dropout_rate,
output_shape=hidden_size,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
name="attention/encdec",
**common_kwargs)
......@@ -342,27 +249,77 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
rate=self.dropout_rate)
self.encdec_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
name="attention/encdec_output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype="float32"))
# Feed-forward projection.
self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self.intermediate_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="intermediate",
**common_kwargs)
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
self._intermediate_dropout_layer = tf.keras.layers.Dropout(
rate=self._intermediate_dropout)
self.output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="output",
**common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
super(TransformerDecoderLayer, self).build(input_shape)
name="output_layer_norm", axis=-1,
epsilon=self._norm_epsilon, dtype="float32")
super().build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self.num_attention_heads,
"intermediate_size":
self.intermediate_size,
"intermediate_activation":
self.intermediate_activation,
"dropout_rate":
self.dropout_rate,
"attention_dropout_rate":
self.attention_dropout_rate,
"multi_channel_cross_attention":
self.multi_channel_cross_attention,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"intermediate_dropout":
self._intermediate_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer)
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block."""
......@@ -375,36 +332,58 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
if self.multi_channel_cross_attention:
if len(inputs) != 5:
raise ValueError(
"TransformerDecoderLayer must have 5 inputs, when it uses "
"TransformerDecoderBlock must have 5 inputs, when it uses "
"multi_channel_cross_attention. But it got: %d" % len(inputs))
elif len(inputs) != 4:
raise ValueError(
"TransformerDecoderLayer must have 4 inputs, but it got: %d" %
"TransformerDecoderBlock must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
self_attention_inputs = [input_tensor, input_tensor]
source_tensor = input_tensor
if self._norm_first:
input_tensor = self.self_attention_layer_norm(input_tensor)
self_attention_output, cache = self.self_attention(
self_attention_inputs,
query=input_tensor,
value=input_tensor,
attention_mask=self_attention_mask,
cache=cache,
decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
cross_attn_inputs = [self_attention_output, memory]
if self._norm_first:
self_attention_output = source_tensor + self_attention_output
else:
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
if self._norm_first:
source_self_attention_output = self_attention_output
self_attention_output = self.encdec_attention_layer_norm(
self_attention_output)
cross_attn_inputs = dict(
query=self_attention_output,
value=memory,
attention_mask=attention_mask)
if self.multi_channel_cross_attention:
      # Accesses the fifth input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1])
attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output)
if self._norm_first:
attention_output = source_self_attention_output + attention_output
else:
attention_output = self.encdec_attention_layer_norm(
self_attention_output + attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self.output_layer_norm(attention_output)
intermediate_output = self.intermediate_dense(attention_output)
intermediate_output = self.intermediate_activation_layer(
intermediate_output)
intermediate_output = self._intermediate_dropout_layer(intermediate_output)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
layer_output = self.output_layer_norm(layer_output + attention_output)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self.output_layer_norm(layer_output + attention_output)
return layer_output, cache
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based transformer scaffold layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import logging
import gin
import tensorflow as tf
......@@ -38,7 +35,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
instantiate the class with the config, or pass a class instance to
`attention_cls`/`feedforward_cls`.
Arguments:
Args:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
......@@ -46,28 +43,25 @@ class TransformerScaffold(tf.keras.layers.Layer):
attention_cfg: The config with which to instantiate `attention_cls`. Ignored
if attention_cls is a layer instance or None. If `attention_cls` is a
class, but `attention_cfg` is None, following kwargs will be used to
instantiate the attention instance:
{
instantiate the attention instance: {
"num_heads": num_attention_heads,
"key_size": int(hidden_size // num_attention_heads),
"key_dim": int(hidden_size // num_attention_heads),
"dropout": attention_dropout_rate,
"name": "self_attention"
}, where `hidden_size` is the input tensor's last dimension.
"name": "self_attention" }, where `hidden_size` is the input tensor's
last dimension.
feedforward_cls: A class to instantiate feedforward layer, or a layer
instance. If None, will use the standard feedforward layer as described
in "Attention Is All You Need" paper. If not None, the instantiated
feedforward layer is expected to take the output of attention as input
and its output is this transformer layer's output.
instance. If None, will use the standard feedforward layer as described in
"Attention Is All You Need" paper. If not None, the instantiated
feedforward layer is expected to take the output of attention as input and
its output is this transformer layer's output.
feedforward_cfg: The config with which to instantiate `feedforward_cls`.
Ignored if feedforward_cls is a layer instance or is None.
If `feedforward_cls` is a class, but `feedforward_cfg` is None, following
kwargs will be used to instantiate the feedforward instance:
{
Ignored if feedforward_cls is a layer instance or is None. If
`feedforward_cls` is a class, but `feedforward_cfg` is None, following
kwargs will be used to instantiate the feedforward instance: {
"intermediate_size": intermediate_size,
"intermediate_activation": intermediate_activation,
"dropout": dropout_rate,
"name": "feedforward"
}.
"name": "feedforward" }.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
kernel_initializer: Initializer for dense layer kernels.
......@@ -89,6 +83,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
feedforward_cfg=None,
dropout_rate=0.0,
attention_dropout_rate=0.0,
norm_first=False,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
......@@ -103,6 +98,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._attention_cls = attention_cls
self._feedforward_cls = feedforward_cls
self._feedforward_cfg = feedforward_cfg
self._norm_first = norm_first
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
......@@ -116,24 +112,14 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape) != 3:
input_tensor_shape = input_shape[0] if (
len(input_shape) == 2) else input_shape
input_tensor_shape = tf.TensorShape(input_tensor_shape)
if len(input_tensor_shape.as_list()) != 3:
raise ValueError(
"TransformerScaffold expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError("When passing a mask tensor to TransformerLayer, the "
"mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
hidden_size = input_tensor_shape[-1]
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
......@@ -160,7 +146,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
default_attention_cfg = {
"num_heads": self._num_heads,
"key_size": self._attention_head_size,
"key_dim": self._attention_head_size,
"dropout": self._attention_dropout_rate,
"name": "self_attention"
}
......@@ -185,12 +171,16 @@ class TransformerScaffold(tf.keras.layers.Layer):
else:
self._feedforward_block = None
# self._dropout_rate controls dropout rates at two places:
# after attention, and after FFN.
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
name="self_attention_layer_norm",
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
if self._feedforward_block is None:
......@@ -200,7 +190,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
bias_axes="d",
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.experimental.global_policy()
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
......@@ -221,6 +211,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
super(TransformerScaffold, self).build(input_shape)
logging.info("%s configs: %s", self.__class__.__name__, self.get_config())
def get_config(self):
config = {
......@@ -238,6 +229,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"norm_first":
self._norm_first,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
......@@ -256,30 +249,57 @@ class TransformerScaffold(tf.keras.layers.Layer):
base_config = super(TransformerScaffold, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
def call(self, inputs, training=None):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs
else:
input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor]
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor, training=training)
attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask,
training=training)
attention_output = self._attention_dropout(attention_output,
training=training)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(input_tensor +
attention_output,
training=training)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output,
training=training)
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
if self._feedforward_block is None:
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
layer_output = self._output_dense(intermediate_output, training=training)
layer_output = self._output_dropout(layer_output, training=training)
# During mixed precision training, attention_output is from layer norm
# and is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output,
training=training)
else:
layer_output = self._feedforward_block(attention_output)
if self._norm_first:
# if norm_first, assume the feedforward block will not apply layer norm
layer_output = self._feedforward_block(attention_output,
training=training)
layer_output += source_attention_output
else:
        # if not norm_first, assume that the feedforward block does apply layer norm
layer_output = self._feedforward_block(attention_output,
training=training)
return layer_output
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,14 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based transformer block layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
"""Tests for Keras-based transformer block layer."""
import numpy as np
import tensorflow as tf
......@@ -39,10 +33,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list
def call(self, inputs, attention_mask=None):
def call(self, query, value, attention_mask=None):
self.list.append(True)
return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask)
query, value, attention_mask=attention_mask)
def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config()
......@@ -89,7 +83,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerLayerTest, self).tearDown()
tf.keras.mixed_precision.experimental.set_policy('float32')
tf.keras.mixed_precision.set_global_policy('float32')
def test_layer_creation(self):
sequence_length = 21
......@@ -98,7 +92,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -126,7 +120,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
feedforward_call_list = []
......@@ -164,7 +158,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -186,30 +180,6 @@ class TransformerLayerTest(keras_parameterized.TestCase):
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_layer_creation_with_incorrect_mask_fails(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
_ = test_layer([data_tensor, mask_tensor])
def test_layer_invocation(self):
sequence_length = 21
width = 80
......@@ -217,7 +187,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -252,7 +222,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
feedforward_call_list = []
......@@ -303,7 +273,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -338,14 +308,14 @@ class TransformerLayerTest(keras_parameterized.TestCase):
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_layer_invocation_with_float16_dtype(self):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy('mixed_float16')
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -386,7 +356,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -414,7 +384,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
'name': 'test_layer',
}
......@@ -447,12 +417,11 @@ class TransformerLayerTest(keras_parameterized.TestCase):
# Serialize the model config. Pass the serialized data through json to
# ensure that we can serialize this layer to disk.
serialized_data = json.dumps(model.get_config())
post_string_serialized_data = json.loads(serialized_data)
serialized_data = model.get_config()
# Create a new model from the old config, and copy the weights. These models
# should have identical outputs.
new_model = tf.keras.Model.from_config(post_string_serialized_data)
new_model = tf.keras.Model.from_config(serialized_data)
new_model.set_weights(model.get_weights())
output = new_model.predict([input_data, mask_data])
......@@ -474,7 +443,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
'name': 'test_layer',
}
......@@ -512,14 +481,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
2, size=(batch_size, sequence_length, sequence_length))
pre_serialization_output = model.predict([input_data, mask_data])
# Serialize the model config. Pass the serialized data through json to
# ensure that we can serialize this layer to disk.
serialized_data = json.dumps(model.get_config())
post_string_serialized_data = json.loads(serialized_data)
serialized_data = model.get_config()
# Create a new model from the old config, and copy the weights. These models
# should have identical outputs.
new_model = tf.keras.Model.from_config(post_string_serialized_data)
new_model = tf.keras.Model.from_config(serialized_data)
new_model.set_weights(model.get_weights())
output = new_model.predict([input_data, mask_data])
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,210 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based transformer block layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for Keras-based transformer block layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import transformer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(('base', transformer.Transformer),
('xla', transformer.CompiledTransformer))
class TransformerLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerLayerTest, self).tearDown()
tf.keras.mixed_precision.experimental.set_policy('float32')
def test_layer_creation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
_ = test_layer([data_tensor, mask_tensor])
def test_layer_invocation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# Create a model from the test layer.
model = tf.keras.Model(data_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
_ = model.predict(input_data)
def test_layer_invocation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_layer_output_range(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor = test_layer([input_data, mask_data])
# The layer only attends to the first token and outputs the first token
    # embedding.
new_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
output_range=1)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_transform_with_initializer(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
def test_dynamic_layer_sequence(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_length = 17
input_data = np.ones((1, input_length, width))
output_data = model.predict(input_data)
self.assertAllEqual([1, input_length, width], output_data.shape)
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
'key':
......@@ -227,12 +32,12 @@ def _create_cache(batch_size, init_decode_length, num_heads, head_size):
@keras_parameterized.run_all_keras_modes
class TransformerDecoderLayerTest(keras_parameterized.TestCase):
class TransformerDecoderBlockTest(keras_parameterized.TestCase):
def test_decoder_block_with_cache(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
decoder_block = transformer.TransformerDecoderBlock(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
......@@ -248,6 +53,47 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
self.assertEqual(output.shape, (2, 4, hidden_size))
self.assertEqual(cache['value'].shape, (2, 4, 2, 8))
def test_use_bias_norm_first(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderBlock(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
intermediate_dropout=0.1,
attention_initializer=tf.keras.initializers.RandomUniform(
minval=0., maxval=1.))
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
output, _ = decoder_block(inputs)
self.assertEqual(output.shape, (2, 4, hidden_size))
def test_get_config(self):
num_attention_heads = 2
decoder_block = transformer.TransformerDecoderBlock(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
intermediate_dropout=0.1,
attention_initializer=tf.keras.initializers.RandomUniform(
minval=0., maxval=1.))
decoder_block_config = decoder_block.get_config()
new_decoder_block = transformer.TransformerDecoderBlock.from_config(
decoder_block_config)
self.assertEqual(decoder_block_config, new_decoder_block.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based Transformer XL layer."""
from absl import logging
import tensorflow as tf
from official.nlp.modeling.layers import relative_attention
def _cache_memory(current_state, previous_state, memory_length, reuse_length=0):
"""Caches hidden states into memory.
Args:
current_state: `Tensor`, the current state.
previous_state: `Tensor`, the previous state.
memory_length: `int`, the number of tokens to cache.
reuse_length: `int`, the number of tokens in the current batch to be cached
and reused in the future.
Returns:
A `Tensor`, representing the cached state with stopped gradients.
"""
if memory_length is None or memory_length == 0:
return None
else:
if reuse_length > 0:
current_state = current_state[:, :reuse_length, :]
if previous_state is None:
new_mem = current_state[:, -memory_length:, :]
else:
new_mem = tf.concat(
[previous_state, current_state], 1)[:, -memory_length:, :]
return tf.stop_gradient(new_mem)
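# Example (illustrative walk-through of the function above): with
# `current_state` of shape [2, 8, 16], `previous_state` of shape [2, 4, 16],
# `memory_length=4` and the default `reuse_length=0`, the concatenation along
# the sequence axis has shape [2, 12, 16], and the returned cache keeps only
# the last 4 positions, i.e. shape [2, 4, 16], with gradients stopped.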
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerXLBlock(tf.keras.layers.Layer):
"""Transformer XL block.
This implements a Transformer XL block from "Transformer-XL: Attentive
Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
This block is further extended to allow for the Transformer-XL
re-parameterization in "XLNet: Generalized Autoregressive Pretraining for
Language Understanding" (https://arxiv.org/abs/1906.08237).
  Given an input stream, this block computes attention, applies dropout and
  layer normalization, and feeds the result into the FFN network.
  **Note: This layer is currently experimental.**
Attributes:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_attention_heads: The number of attention heads.
head_size: The dimension size of each attention head.
inner_size: The inner size for the transformer layers.
dropout_rate: Dropout rate for the output of this layer.
attention_dropout_rate: Dropout rate on attention probabilities.
two_stream: Whether or not to use `TwoStreamRelativeAttention` used in the
XLNet pretrainer. If `False`, then it will use
`MultiHeadRelativeAttention` as in Transformer XL.
norm_epsilon: Epsilon value to initialize normalization layers.
inner_activation: The activation to use for the inner
FFN layers.
kernel_initializer: Initializer for dense layer kernels.
inner_dropout: Dropout probability for the inner dropout
layer.
"""
def __init__(self,
vocab_size,
hidden_size,
num_attention_heads,
head_size,
inner_size,
dropout_rate,
attention_dropout_rate,
two_stream=False,
norm_epsilon=1e-12,
inner_activation="relu",
kernel_initializer="variance_scaling",
inner_dropout=0.0,
**kwargs):
"""Initializes TransformerXLBlock layer."""
super(TransformerXLBlock, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._num_heads = num_attention_heads
self._head_size = head_size
self._hidden_size = hidden_size
self._inner_size = inner_size
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._inner_activation = inner_activation
self._norm_epsilon = norm_epsilon
self._kernel_initializer = kernel_initializer
self._inner_dropout = inner_dropout
self._two_stream = two_stream
if two_stream:
self._attention_layer_type = relative_attention.TwoStreamRelativeAttention
else:
self._attention_layer_type = relative_attention.MultiHeadRelativeAttention
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape.as_list()) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError("When passing a mask tensor to TransformerXLBlock, "
"the mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_layer = self._attention_layer_type(
num_heads=self._num_heads,
key_dim=self._head_size,
value_dim=self._head_size,
dropout=self._attention_dropout_rate,
use_bias=False,
kernel_initializer=self._kernel_initializer,
name="rel_attn")
self._attention_dropout = tf.keras.layers.Dropout(
rate=self._attention_dropout_rate)
self._attention_layer_norm = tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
self._inner_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._inner_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="inner")
self._inner_activation_layer = tf.keras.layers.Activation(
self._inner_activation)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon)
super(TransformerXLBlock, self).build(input_shape)
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"hidden_size":
self._hidden_size,
"num_attention_heads":
self._num_heads,
"head_size":
self._head_size,
"inner_size":
self._inner_size,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"two_stream":
self._two_stream,
"norm_epsilon":
self._norm_epsilon,
"inner_activation":
self._inner_activation,
"kernel_initializer":
self._kernel_initializer,
"inner_dropout":
self._inner_dropout,
}
base_config = super(TransformerXLBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self,
content_stream,
content_attention_bias,
positional_attention_bias,
relative_position_encoding=None,
segment_matrix=None,
segment_encoding=None,
segment_attention_bias=None,
state=None,
content_attention_mask=None,
query_stream=None,
query_attention_mask=None,
target_mapping=None):
"""Implements `call` for the Layer.
Args:
content_stream: `Tensor`, the input content stream. This is the standard
input to Transformer XL and is commonly referred to as `h` in XLNet.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
positional_attention_bias: Bias `Tensor` for position based attention of
shape `[num_heads, dim]`.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet,
but not in Transformer XL.
segment_encoding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
in XLNet, but not in Transformer XL.
segment_attention_bias: Optional bias `Tensor` for segment based attention
of shape `[num_heads, dim]`.
state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
the state or memory. If passed, this is also attended over as in
Transformer XL.
content_attention_mask: Optional `Tensor` representing the mask that is
        added to content attention logits. If state is not None, the mask's
        source sequence dimension should be extended by M to also cover the
        memory tokens.
query_stream: Optional `Tensor`, the query stream. This is introduced in
`TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
`two_stream` is `False`.
query_attention_mask: Optional `Tensor` representing the mask that is
        added to query attention logits. If state is not None, the mask's
        source sequence dimension should be extended by M to also cover the
        memory tokens.
target_mapping: Optional `Tensor` representing the target mapping when
calculating query attention.
Returns:
A `dict` object, containing the key value pairs for `content_attention`
and (if `two_stream` is `True`) `query_attention`.
"""
if not self._two_stream and query_stream is not None:
logging.warning("`query_stream` was provided but two stream attention is "
"disabled. `query_stream` will be ignored.")
if self._two_stream:
attention_kwargs = dict(
content_stream=content_stream,
query_stream=query_stream,
query_attention_mask=query_attention_mask,
target_mapping=target_mapping,
content_attention_mask=content_attention_mask)
else:
attention_kwargs = dict(
query=content_stream,
value=content_stream,
key=content_stream,
attention_mask=content_attention_mask)
common_attention_kwargs = dict(
content_attention_bias=content_attention_bias,
relative_position_encoding=relative_position_encoding,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
state=state)
attention_kwargs.update(common_attention_kwargs)
attention_output = self._attention_layer(**attention_kwargs)
if self._two_stream:
attention_streams = attention_output
input_streams = [content_stream, query_stream]
else:
attention_streams = [attention_output]
input_streams = [content_stream]
attention_keys = ["content_attention", "query_attention"]
attention_output = {}
for attention_stream, input_stream, attention_key in zip(
attention_streams, input_streams, attention_keys):
attention_stream = self._attention_dropout(attention_stream)
attention_stream = self._attention_layer_norm(
attention_stream + input_stream)
inner_output = self._inner_dense(attention_stream)
inner_output = self._inner_activation_layer(
inner_output)
inner_output = self._inner_dropout_layer(
inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
layer_output = self._output_layer_norm(layer_output + attention_stream)
attention_output[attention_key] = layer_output
return attention_output
class TransformerXL(tf.keras.layers.Layer):
"""Transformer XL.
This layer combines multiple Transformer XL blocks from "Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
This layer handles the attention biases as well as memory caching and reuse
as in Transformer XL and XLNet.
Attributes:
vocab_size: The number of tokens in vocabulary.
num_layers: The number of layers.
hidden_size: The hidden size.
num_attention_heads: The number of attention heads.
head_size: The dimension size of each attention head.
inner_size: The hidden size in feed-forward layers.
dropout_rate: Dropout rate used in each Transformer XL block.
attention_dropout_rate: Dropout rate on attention probabilities.
two_stream: Whether or not to use `TwoStreamRelativeAttention` used
in the XLNet pretrainer. If `False`, then it will use
`MultiHeadRelativeAttention` as in Transformer XL.
initializer: The initializer to use for attention biases.
tie_attention_biases: Whether or not to tie biases together. If `True`, then
each Transformer XL block shares the same trainable attention bias. If
`False`, then each block has its own attention bias. This is usually set
to `True`.
memory_length: The number of tokens to cache.
reuse_length: The number of tokens in the current batch to be cached
and reused in the future.
inner_activation: The activation to use in the inner layers
for Transformer XL blocks. Typically "relu" or "gelu".
"""
def __init__(self,
vocab_size,
num_layers,
hidden_size,
num_attention_heads,
head_size,
inner_size,
dropout_rate,
attention_dropout_rate,
initializer,
two_stream=False,
tie_attention_biases=True,
memory_length=None,
reuse_length=None,
inner_activation="relu",
**kwargs):
"""Initializes TransformerXL."""
super(TransformerXL, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._initializer = initializer
self._num_layers = num_layers
self._hidden_size = hidden_size
self._num_attention_heads = num_attention_heads
self._head_size = head_size
self._inner_size = inner_size
self._inner_activation = inner_activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._tie_attention_biases = tie_attention_biases
self._two_stream = two_stream
self._memory_length = memory_length
self._reuse_length = reuse_length
if self._tie_attention_biases:
attention_bias_shape = [self._num_attention_heads, self._head_size]
else:
attention_bias_shape = [self._num_layers, self._num_attention_heads,
self._head_size]
self.content_attention_bias = self.add_weight(
"content_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.positional_attention_bias = self.add_weight(
"positional_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.segment_attention_bias = self.add_weight(
"segment_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
self.transformer_xl_layers = []
for i in range(self._num_layers):
self.transformer_xl_layers.append(
TransformerXLBlock(
vocab_size=self._vocab_size,
hidden_size=self._head_size * self._num_attention_heads,
num_attention_heads=self._num_attention_heads,
head_size=self._head_size,
inner_size=self._inner_size,
dropout_rate=self._dropout_rate,
attention_dropout_rate=self._attention_dropout_rate,
norm_epsilon=1e-12,
inner_activation=self._inner_activation,
two_stream=self._two_stream,
kernel_initializer="variance_scaling",
name="layer_%d" % i))
self.output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"num_layers":
self._num_layers,
"hidden_size":
self._hidden_size,
"num_attention_heads":
self._num_attention_heads,
"head_size":
self._head_size,
"inner_size":
self._inner_size,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"initializer":
self._initializer,
"two_stream":
self._two_stream,
"tie_attention_biases":
self._tie_attention_biases,
"memory_length":
self._memory_length,
"reuse_length":
self._reuse_length,
"inner_activation":
self._inner_activation,
}
base_config = super(TransformerXL, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self,
content_stream,
relative_position_encoding,
segment_matrix=None,
segment_embedding=None,
state=None,
content_attention_mask=None,
query_stream=None,
query_attention_mask=None,
target_mapping=None):
"""Implements call() for the layer.
Args:
content_stream: `Tensor`, the input content stream. This is the standard
input to Transformer XL and is commonly referred to as `h` in XLNet.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet,
but not in Transformer XL.
segment_embedding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
in XLNet, but not in Transformer XL.
state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
the state or memory. If passed, this is also attended over as in
Transformer XL.
content_attention_mask: Optional `Tensor` representing the mask that is
        added to content attention logits. If state is not None, the mask's
        source sequence dimension should be extended by M to also cover the
        memory tokens.
query_stream: Optional `Tensor`, the query stream. This is introduced in
`TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
`two_stream` is `False`.
query_attention_mask: Optional `Tensor` representing the mask that is
        added to query attention logits. If state is not None, the mask's
        source sequence dimension should be extended by M to also cover the
        memory tokens.
target_mapping: Optional `Tensor` representing the target mapping when
calculating query attention.
Returns:
A tuple consisting of the attention output and the list of cached memory
states.
The attention output is `content_attention` if `two_stream` is `False`,
otherwise it is `query_attention`.
"""
new_mems = []
if state is None:
state = [None] * self._num_layers
for i in range(self._num_layers):
# cache new mems
new_mems.append(
_cache_memory(content_stream, state[i],
self._memory_length, self._reuse_length))
# segment bias
if segment_matrix is None:
segment_attention_bias = None
segment_encoding = None
else:
segment_attention_bias = (self.segment_attention_bias
if self._tie_attention_biases
else self.segment_attention_bias[i])
segment_encoding = segment_embedding[i]
content_attention_bias = (self.content_attention_bias
if self._tie_attention_biases
else self.content_attention_bias[i])
positional_attention_bias = (self.positional_attention_bias
if self._tie_attention_biases
else self.positional_attention_bias[i])
transformer_xl_layer = self.transformer_xl_layers[i]
transformer_xl_output = transformer_xl_layer(
content_stream=content_stream,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
relative_position_encoding=relative_position_encoding,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
state=state[i],
content_attention_mask=content_attention_mask,
query_attention_mask=query_attention_mask,
query_stream=query_stream,
target_mapping=target_mapping)
content_stream = transformer_xl_output["content_attention"]
if self._two_stream:
query_stream = transformer_xl_output["query_attention"]
else:
query_stream = None
if self._two_stream:
output_stream = query_stream
else:
output_stream = content_stream
return output_stream, new_mems
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Transformer XL."""
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import transformer_xl
def create_mock_transformer_xl_data(
batch_size,
num_heads,
head_size,
hidden_size,
seq_length,
memory_length=0,
num_predictions=2,
two_stream=False,
num_layers=1,
include_biases=True,
include_state=False,
include_mask=False,
include_segment=False):
"""Creates mock testing data.
Args:
batch_size: `int`, the batch size.
num_heads: `int`, number of attention heads.
head_size: `int`, the size of each attention head.
hidden_size: `int`, the layer's hidden size.
seq_length: `int`, Sequence length of the input.
memory_length: optional `int`, the length of the state. Defaults to 0.
num_predictions: `int`, the number of predictions used in two stream
attention.
two_stream: `bool`, whether or not to generate two stream data.
num_layers: `int`, the number of Transformer XL blocks.
include_biases: optional `bool`, whether or not to include attention biases.
include_state: optional `bool`, whether or not to include state data.
include_mask: optional `bool`, whether or not to include mask data.
include_segment: optional `bool`, whether or not to include segment data.
Returns:
A dictionary with `str` as keys and `Tensor` as values.
"""
encoding_shape = (batch_size, seq_length * 2, hidden_size)
data = dict(
relative_position_encoding=tf.random.normal(shape=encoding_shape),
content_stream=tf.random.normal(
shape=(batch_size, seq_length, hidden_size)))
if include_biases:
attention_bias_shape = (num_heads, head_size)
data.update(dict(
content_attention_bias=tf.random.normal(shape=attention_bias_shape),
segment_attention_bias=tf.random.normal(shape=attention_bias_shape),
positional_attention_bias=tf.random.normal(shape=attention_bias_shape)))
if two_stream:
data.update(dict(
query_stream=tf.random.normal(
shape=(batch_size, num_predictions, hidden_size)),
target_mapping=tf.random.normal(
shape=(batch_size, num_predictions, seq_length))))
if include_state:
total_seq_length = seq_length + memory_length
if num_layers > 1:
state_shape = (num_layers, batch_size, memory_length, hidden_size)
else:
state_shape = (batch_size, memory_length, hidden_size)
data.update(dict(
state=tf.random.normal(shape=state_shape)))
else:
total_seq_length = seq_length
if include_mask:
mask_shape = (batch_size, num_heads, seq_length, total_seq_length)
mask_data = np.random.randint(2, size=mask_shape).astype("float32")
data["content_attention_mask"] = mask_data
if two_stream:
data["query_attention_mask"] = mask_data
if include_segment:
# A transformer XL block takes an individual segment "encoding" from the
# entirety of the Transformer XL segment "embedding".
if num_layers > 1:
segment_encoding_shape = (num_layers, 2, num_heads, head_size)
segment_encoding_name = "segment_embedding"
else:
segment_encoding_shape = (2, num_heads, head_size)
segment_encoding_name = "segment_encoding"
segment_matrix = np.random.randint(
2, size=(batch_size, seq_length, total_seq_length))
data["segment_matrix"] = tf.math.equal(segment_matrix, 1)
data[segment_encoding_name] = tf.random.normal(shape=segment_encoding_shape)
return data
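# Example usage (illustrative; see TransformerXLBlockTest below for the full
# version):
#   data = create_mock_transformer_xl_data(
#       batch_size=2, num_heads=12, head_size=64, hidden_size=24, seq_length=8)
#   output = transformer_xl.TransformerXLBlock(
#       vocab_size=32000, hidden_size=24, num_attention_heads=12, head_size=64,
#       inner_size=12, dropout_rate=0., attention_dropout_rate=0.)(**data)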
@keras_parameterized.run_all_keras_modes
class TransformerXLBlockTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
memory_length=[0, 4],
two_stream=[True, False],
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_transformer_xl_block(
self,
two_stream,
memory_length,
state,
mask,
segment):
"""Tests combinations of Transformer XL block calculations."""
batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
hidden_size, num_predictions, inner_size = 24, 8, 12
data = create_mock_transformer_xl_data(
include_biases=True,
num_heads=num_heads,
head_size=head_size,
hidden_size=hidden_size,
seq_length=seq_length,
batch_size=batch_size,
memory_length=memory_length,
num_predictions=num_predictions,
two_stream=two_stream,
include_state=state,
include_mask=mask,
include_segment=segment)
test_layer = transformer_xl.TransformerXLBlock(
vocab_size=32000,
hidden_size=hidden_size,
num_attention_heads=num_heads,
head_size=head_size,
inner_size=inner_size,
dropout_rate=0.,
attention_dropout_rate=0.,
two_stream=two_stream)
output = test_layer(**data)
content_attention = output["content_attention"]
self.assertEqual(content_attention.shape,
[batch_size, seq_length, hidden_size])
if two_stream:
self.assertIn("query_attention", output)
self.assertEqual(output["query_attention"].shape,
[batch_size, num_predictions, hidden_size])
else:
self.assertNotIn("query_attention", output)

  def test_get_config(self):
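    # Serialization round-trip: rebuilding the block from its own config should
    # reproduce an identical config (standard Keras get_config/from_config
    # contract).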
transformer_xl_block = transformer_xl.TransformerXLBlock(
vocab_size=32000,
head_size=64,
num_attention_heads=2,
hidden_size=10,
inner_size=50,
dropout_rate=0.,
attention_dropout_rate=0.,
two_stream=False)
transformer_xl_block_config = transformer_xl_block.get_config()
new_block = transformer_xl.TransformerXLBlock.from_config(
transformer_xl_block_config)
self.assertEqual(transformer_xl_block_config, new_block.get_config())


@keras_parameterized.run_all_keras_modes
class TransformerXLTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
two_stream=[True, False],
memory_length=[0, 4],
reuse_length=[0, 4],
tie_attention_biases=[True, False],
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_transformer_xl(
self,
two_stream,
memory_length,
reuse_length,
tie_attention_biases,
state,
mask,
segment):
batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
hidden_size, num_predictions, inner_size = 24, 8, 12
num_layers = 3
data = create_mock_transformer_xl_data(
include_biases=False,
num_heads=num_heads,
head_size=head_size,
hidden_size=hidden_size,
seq_length=seq_length,
batch_size=batch_size,
memory_length=memory_length,
num_predictions=num_predictions,
two_stream=two_stream,
num_layers=num_layers,
include_state=state,
include_mask=mask,
include_segment=segment)
transformer_xl_layer = transformer_xl.TransformerXL(
vocab_size=32000,
num_layers=num_layers,
head_size=head_size,
hidden_size=hidden_size,
num_attention_heads=num_heads,
inner_size=inner_size,
dropout_rate=0.,
attention_dropout_rate=0.,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=two_stream,
tie_attention_biases=tie_attention_biases,
memory_length=memory_length,
reuse_length=reuse_length,
inner_activation="relu")
attention_output, cached_memory_states = transformer_xl_layer(**data)
if two_stream:
self.assertEqual(attention_output.shape,
[batch_size, num_predictions, hidden_size])
else:
self.assertEqual(attention_output.shape,
[batch_size, seq_length, hidden_size])
self.assertEqual(len(cached_memory_states), num_layers)

  def test_get_config(self):
transformer_xl_layer = transformer_xl.TransformerXL(
vocab_size=32000,
num_layers=12,
hidden_size=36,
head_size=12,
num_attention_heads=12,
inner_size=12,
dropout_rate=0.,
attention_dropout_rate=0.,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=False,
tie_attention_biases=True,
memory_length=0,
reuse_length=0,
inner_activation="relu")
transformer_xl_config = transformer_xl_layer.get_config()
new_transformer_xl = transformer_xl.TransformerXL.from_config(
transformer_xl_config)
self.assertEqual(transformer_xl_config, new_transformer_xl.get_config())


if __name__ == "__main__":
np.random.seed(0)
tf.random.set_seed(0)
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,13 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based transformer block layer."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""Keras-based transformer block layer."""
import functools
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,6 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Activations package definition. Subject to change."""
"""Losses contains common loss computation used in NLP (subject to change)."""
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,13 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Weighted sparse categorical cross-entropy losses."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""Weighted sparse categorical cross-entropy losses."""
import tensorflow as tf
......