Commit c7e31961 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 332476566
parent ba206271
@@ -27,5 +27,6 @@ from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
 from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
 from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
 from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
+from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
 from official.nlp.modeling.layers.transformer import *
 from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ExpandCondense tensor network layer used in TN-BERT."""
# pylint: disable=g-classes-have-attributes
from typing import List, Optional, Text, Any, Dict
import tensorflow as tf
Layer = tf.keras.layers.Layer
activations = tf.keras.activations
initializers = tf.keras.initializers
@tf.keras.utils.register_keras_serializable(package='Text')
class TNExpandCondense(Layer):
"""A TPU-optimized TensorNetwork layer.
  Designed for use in models that currently use a pair of Dense layers to
  perform an up projection followed by a down projection.

  This layer is a TPU-optimized combination of three operations: Expand, Apply
  Activation, and Condense. The layer projects up from `input_shape[-1]` to
  `input_shape[-1] * proj_multiplier`, applies `self.activation`, and then
  condenses back to `input_shape[-1]`. Note that the input and output shapes
  are identical.
  Arguments:
    proj_multiplier: Positive integer, multiple of `input_shape[-1]` to project
      up to. Must be one of [2, 4, 6, 8, 10, 12].
    use_bias: Boolean, whether the layer uses a bias vector.
    activation: Activation function to use between Expand and Condense.
      Defaults to 'relu'. If `None` is passed, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    kernel_initializer: Initializer for the weight matrices.
    bias_initializer: Initializer for the bias vector.
Input shape:
N-D tensor with shape: `(batch_size, ..., input_shape[-1])`.
Output shape:
N-D tensor with shape: `(batch_size, ..., input_shape[-1])`.
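
  Example (a minimal usage sketch; the shapes are illustrative, and the last
  input dimension must be a multiple of 128):

    layer = TNExpandCondense(proj_multiplier=4, activation='relu')
    x = tf.ones((2, 16, 1024))
    y = layer(x)  # Same shape as the input: (2, 16, 1024).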
"""
def __init__(self,
proj_multiplier: int,
use_bias: Optional[bool] = True,
activation: Optional[Text] = 'relu',
kernel_initializer: Optional[Text] = 'glorot_uniform',
bias_initializer: Optional[Text] = 'zeros',
**kwargs) -> None:
    # Allow specification of input_dim instead of input_shape,
    # for compatibility with Keras layers that support this.
if 'input_shape' not in kwargs and 'input_dim' in kwargs:
kwargs['input_shape'] = (kwargs.pop('input_dim'),)
super(TNExpandCondense, self).__init__(**kwargs)
assert proj_multiplier in [
2, 4, 6, 8, 10, 12
], 'proj_multiplier needs to be one of [2, 4, 6, 8, 10, 12]'
self.proj_multiplier = proj_multiplier
self.use_bias = use_bias
self.activation = activations.get(activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
def build(self, input_shape: List[int]) -> None:
# Disable the attribute-defined-outside-init violations in this function
# pylint: disable=attribute-defined-outside-init
if input_shape[-1] is None:
raise ValueError(
'The last dimension of the inputs to `TNExpandCondense` '
'should be defined. Found `None`.')
super(TNExpandCondense, self).build(input_shape)
self.proj_size = self.proj_multiplier * input_shape[-1]
    assert (self.proj_size // input_shape[-1]) * input_shape[
        -1] == self.proj_size, (
            f'proj_size ({self.proj_size}) must be an exact multiple of '
            f'the input dimension ({input_shape[-1]}).')
    assert (input_shape[-1] // 128) * 128 == input_shape[-1], (
        f'The input dimension ({input_shape[-1]}) must be a multiple of 128.')
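    # The up/down projection is factored into four smaller weight tensors.
    # With D = input_shape[-1] and k = proj_size // D (= proj_multiplier):
    #   w1: (D, D)              mixes the input features.
    #   w2: (128, 128 * k)      expands each 128-wide block by a factor of k.
    #   w3: (128 * k, 128)      condenses each block back to width 128.
    #   w4: (D // 128, 128, D)  recombines the blocks into the output features.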
self.w1 = self.add_weight(
name='w1',
shape=(input_shape[-1], input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
self.w2 = self.add_weight(
name='w2',
shape=(128, (128 * (self.proj_size // input_shape[-1]))),
trainable=True,
initializer=self.kernel_initializer)
self.w3 = self.add_weight(
name='w3',
shape=(128 * (self.proj_size // input_shape[-1]), 128),
trainable=True,
initializer=self.kernel_initializer)
self.w4 = self.add_weight(
name='w4',
shape=(input_shape[-1] // 128, 128, input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
if self.use_bias:
self.bias = self.add_weight(
name='b',
shape=(input_shape[-1] // 128, 1,
128 * (self.proj_size // input_shape[-1])),
trainable=True,
initializer=self.bias_initializer)
else:
self.bias = None
def call(self, inputs: tf.Tensor, **kwargs):
orig_shape = tf.shape(inputs)
input_dim = inputs.shape[-1]
tmp = tf.reshape(inputs, (-1, input_dim))
# Shape is (BatchSeq, input_dim)
# Expansion network
tmp = tf.einsum('ab,Qb->aQ', self.w1, tmp)
# Note: Letter Q will always represent the BatchSeq axis.
tmp = tf.reshape(tmp, (input_dim // 128, 128, -1))
tmp = tf.einsum('abQ,bd->aQd', tmp, self.w2)
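    # Shape is now (input_dim // 128, BatchSeq, 128 * proj_multiplier):
    # every 128-wide block of features has been expanded.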
    # Apply the (optional) bias and activation, then Condense.
    if self.bias is not None:
      tmp = tmp + self.bias
    tmp = self.activation(tmp)
tmp = tf.einsum('aQd,db->aQb', tmp, self.w3)
tmp = tf.einsum('aQb,abd->Qd', tmp, self.w4)
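    # Shape is now (BatchSeq, input_dim); restore the original
    # batch/sequence dimensions below.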
out = tf.reshape(tmp, orig_shape)
return out
def compute_output_shape(self, input_shape: List[int]) -> List[int]:
return input_shape
def get_config(self) -> Dict[Any, Any]:
"""Returns the config of the layer.
The same layer can be reinstantiated later
(without its trained weights) from this configuration.
Returns:
Python dictionary containing the configuration of the layer.
"""
config = {}
# Include the layer-specific arguments
args = ['proj_multiplier', 'use_bias']
for arg in args:
config[arg] = getattr(self, arg)
# Serialize the activation
config['activation'] = activations.serialize(getattr(self, 'activation'))
# Serialize the initializers
decomp_initializers = ['kernel_initializer', 'bias_initializer']
for initializer_arg in decomp_initializers:
config[initializer_arg] = initializers.serialize(
getattr(self, initializer_arg))
# Get base config
base_config = super(TNExpandCondense, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ExpandCondense tensor network layer."""
import os
import shutil
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.testing_utils import layer_test
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
"""Unit tests for ExpandCondense TN layer.
"""
def setUp(self):
super(TNLayerTest, self).setUp()
self.labels = np.concatenate((np.ones((50, 1)), np.zeros((50, 1))), axis=0)
def _build_model(self, data, proj_multiple=2):
model = tf.keras.models.Sequential()
model.add(
TNExpandCondense(
proj_multiplier=proj_multiple,
use_bias=True,
activation='relu',
input_shape=(data.shape[-1],)))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
return model
@parameterized.parameters((768, 6), (1024, 2))
def test_keras_layer(self, input_dim, proj_multiple):
data = np.random.normal(size=(100, input_dim))
data = data.astype(np.float32)
layer_test(
TNExpandCondense,
kwargs={
'proj_multiplier': proj_multiple,
'input_shape': data.shape
},
input_shape=data.shape,
input_data=data,
expected_output_shape=(None, data.shape[-1]),
expected_output_dtype=data.dtype)
@parameterized.parameters((768, 6), (1024, 2))
def test_train(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
tf.random.set_seed(0)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train the model for 5 epochs
history = model.fit(data, self.labels, epochs=5, batch_size=32)
# Check that loss decreases and accuracy increases
self.assertGreater(history.history['loss'][0], history.history['loss'][-1])
self.assertLess(
history.history['accuracy'][0], history.history['accuracy'][-1])
@parameterized.parameters((768, 6), (1024, 2))
def test_weights_change(self, input_dim, proj_multiple):
tf.random.set_seed(0)
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
before = model.get_weights()
model.fit(data, self.labels, epochs=5, batch_size=32)
after = model.get_weights()
# Make sure every layer's weights changed
for i, _ in enumerate(before):
self.assertTrue((after[i] != before[i]).any())
@parameterized.parameters((768, 6), (1024, 2))
def test_output_shape(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
input_shape = data.shape
actual_output_shape = model(data).shape
expected_output_shape = model.compute_output_shape(input_shape)
self.assertEqual(expected_output_shape, actual_output_shape)
@parameterized.parameters((768, 6), (1024, 2))
def test_expandcondense_num_parameters(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
proj_size = proj_multiple * data.shape[-1]
model = tf.keras.models.Sequential()
model.add(
TNExpandCondense(
proj_multiplier=proj_multiple,
use_bias=True,
activation='relu',
input_shape=(data.shape[-1],)))
w1_params = data.shape[-1]**2
w2_params = 128 * 128 * (proj_size // data.shape[-1])
w3_params = 128 * 128 * (proj_size // data.shape[-1])
w4_params = (data.shape[-1] // 128) * 128 * data.shape[-1]
bias_params = ((data.shape[-1] // 128) * 128 *
(proj_size // data.shape[-1]))
expected_num_parameters = (w1_params + w2_params + w3_params +
w4_params) + bias_params
self.assertEqual(expected_num_parameters, model.count_params())
@parameterized.parameters((912, 6), (200, 2))
def test_incorrect_sizes(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
with self.assertRaises(AssertionError):
model = self._build_model(data, proj_multiple)
model.compile(optimizer='adam', loss='binary_crossentropy')
@parameterized.parameters((768, 6), (1024, 2))
def test_config(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
expected_num_parameters = model.layers[0].count_params()
# Serialize model and use config to create new layer
model_config = model.get_config()
layer_config = model_config['layers'][1]['config']
new_model = TNExpandCondense.from_config(layer_config)
# Build the layer so we can count params below
new_model.build(layer_config['batch_input_shape'])
# Check that original layer had same num params as layer built from config
self.assertEqual(expected_num_parameters, new_model.count_params())
@parameterized.parameters((768, 6), (1024, 2))
def test_model_save(self, input_dim, proj_multiple):
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train the model for 5 epochs
model.fit(data, self.labels, epochs=5, batch_size=32)
for save_path in ['/test_model', '/test_model.h5']:
# Save model to a SavedModel folder or h5 file, then load model
save_path = os.environ['TEST_UNDECLARED_OUTPUTS_DIR'] + save_path
model.save(save_path)
loaded_model = tf.keras.models.load_model(save_path)
# Clean up SavedModel folder
if os.path.isdir(save_path):
shutil.rmtree(save_path)
# Clean up h5 file
if os.path.exists(save_path):
os.remove(save_path)
# Compare model predictions and loaded_model predictions
self.assertAllEqual(model.predict(data), loaded_model.predict(data))
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TN-BERT TNTransformerExpandCondense employing Expand-Condense layer instead of Dense."""
# pylint: disable=g-classes-have-attributes
# Import libraries
import gin
import tensorflow as tf
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TNTransformerExpandCondense(tf.keras.layers.Layer):
"""Transformer layer using tensor network Expand-Condense layer.
This layer implements the Transformer from transformer.py, with a single
tensor network layer replacing the usual intermediate and output Dense
layers.
Arguments:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
    dropout_rate: Dropout probability for the post-attention and output
      dropout.
    attention_dropout_rate: Dropout probability within the attention layer.
    output_range: The sequence output range, [0, output_range), obtained by
      slicing the target sequence. `None` means the target sequence is not
      sliced.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
    use_bias: Whether to use a bias in the attention layer. If set to False,
      the attention layer's bias is disabled.
    norm_first: Whether to normalize the inputs to the attention and
      intermediate dense layers (pre-norm). If set to False, the outputs of
      the attention and intermediate dense layers are normalized instead
      (post-norm).
    norm_epsilon: Epsilon value used to initialize normalization layers.
    intermediate_dropout: Dropout probability for the intermediate dropout
      layer.
    attention_initializer: Initializer for kernels of attention layers. If set
      to `None`, the attention layers use `kernel_initializer` for their
      kernels.
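
  Example (a minimal usage sketch; the shapes follow the unit tests and are
  illustrative only):

    layer = TNTransformerExpandCondense(
        num_attention_heads=16,
        intermediate_size=2048,
        intermediate_activation='relu')
    data = tf.keras.Input(shape=(21, 256))
    mask = tf.keras.Input(shape=(21, 21))
    output = layer([data, mask])  # Same shape as `data`: (None, 21, 256).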
"""
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
intermediate_dropout=0.0,
attention_initializer=None,
**kwargs):
super(TNTransformerExpandCondense, self).__init__(**kwargs)
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape.as_list()) != 3:
raise ValueError(
"TNTransformerExpandCondense expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError(
"When passing a mask tensor to TNTransformerExpandCondense, the "
"mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._attention_layer = tf.keras.layers.MultiHeadAttention(
num_heads=self._num_heads,
key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
# Substitute Dense layers with a single Expand-Condense layer.
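    # Note: the expansion factor passed to TNExpandCondense is fixed at 4, so
    # the effective intermediate width is 4 * hidden_size regardless of the
    # `intermediate_size` argument.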
self._output_dense = TNExpandCondense(
4,
use_bias=True,
activation=self._intermediate_activation,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
super(TNTransformerExpandCondense, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"intermediate_dropout":
self._intermediate_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer)
}
base_config = super(TNTransformerExpandCondense, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs
else:
input_tensor, attention_mask = (inputs, None)
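    # With norm_first=True, layer norm is applied to the inputs of the
    # attention and output blocks and the residual connections wrap the
    # un-normalized tensors (pre-norm); otherwise layer norm is applied to the
    # residual sums (post-norm).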
if self._output_range:
target_tensor = input_tensor[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
target_tensor = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
layer_output = self._output_dense(attention_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TN-BERT transformer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(('tn', TNTransformerExpandCondense))
class TransformerLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerLayerTest, self).tearDown()
tf.keras.mixed_precision.experimental.set_policy('float32')
def test_layer_creation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer should have the same shape as
    # the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
    # The default output of a transformer layer should have the same shape as
    # the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
_ = test_layer([data_tensor, mask_tensor])
def test_layer_invocation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# Create a model from the test layer.
model = tf.keras.Model(data_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
_ = model.predict(input_data)
def test_layer_invocation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_layer_output_range(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
batch_size = 6
input_data = 16 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor = test_layer([input_data, mask_data])
    # The layer only attends to the first token and outputs the first token
    # embedding.
new_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
output_range=1)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (16 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_transform_with_initializer(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
sequence_length = 21
width = 256
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output = test_layer(data_tensor)
    # The default output of a transformer layer should have the same shape as
    # the input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
def test_dynamic_layer_sequence(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=16,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
width = 256
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_length = 17
input_data = np.ones((1, input_length, width))
output_data = model.predict(input_data)
self.assertAllEqual([1, input_length, width], output_data.shape)
if __name__ == '__main__':
tf.test.main()