Commit 193a6088 authored by James Lee-Thorp, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 474606464
parent 150975e8
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based mixing layers.
Based on the mixing layers use by FNet
(https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers
(https://arxiv.org/abs/2205.12399).
Mixing layers can be used as drop in replacements for self-attention layers. For
interoperability with attention layers, we use the same `query` and `value` call
signature.
Note: These mixing layers currently only support encoder stacks. Decoder stacks
can be supported in the future by utilizing the `value` inputs.
"""
import functools
from typing import Callable, Tuple, Union
import numpy as np
from scipy import linalg
import tensorflow as tf
from official.modeling import tf_utils
_Initializer = Union[str, tf.keras.initializers.Initializer]
default_kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=2e-2)
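
# Example usage (an illustrative sketch; `embeddings` stands for any
# [batch_size, max_seq_length, hidden_dim] float tensor): because mixing
# layers share the attention call signature, they can replace a
# self-attention layer in an encoder block without other changes:
#
#   mixing_layer = FourierTransformLayer(use_fft=True)
#   outputs = mixing_layer(query=embeddings, value=embeddings)
#   # outputs has the same shape as embeddings.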
class MixingLayer(tf.keras.layers.Layer):
"""Mixing layer base class.
This class cannot be used directly. It just specifies the API for mixing
layer subclasses. For interoperability with attention layers, we use the same
`query` and `value` call signature.
Based on the mixing layers use by FNet
(https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers
(https://arxiv.org/abs/2205.12399).
"""
def __init__(self, name: str = "mixing", **kwargs):
"""Initializes layer.
Args:
name: Name for layer.
**kwargs: Keyword arguments.
"""
super().__init__(name=name, **kwargs)
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
"""Calls the layer.
Subclasses should return tensors of shape
<float>[batch_size, max_seq_length, hidden_dim].
Args:
query: Batch of input embeddings, typically of shape <float>[batch_size,
max_seq_length, hidden_dim].
value: Unused. Included to match attention layer API.
**kwargs: Optional arguments to catch unused attention keyword arguments.
Raises:
NotImplementedError. This class should not be called directly.
"""
raise NotImplementedError("Abstract method")
class FourierTransformLayer(MixingLayer):
"""Fourier Transform layer.
Applies 2D Fourier Transform over final two dimensions of `query` inputs -
typically the sequence and hidden dimensions.
"""
def __init__(self,
use_fft: bool = False,
name: str = "fourier_transform",
**kwargs):
"""Initializes layer.
Args:
use_fft: Whether to use Fast Fourier Transform (True) or the Discrete
Fourier Transform (DFT) matrix (False) to compute the Fourier Transform.
See _pick_fourier_transform() for recommendations on when to use FFT or
DFT.
name: Name for layer.
**kwargs: Keyword arguments.
"""
super().__init__(name=name, **kwargs)
self.use_fft = use_fft
def build(self, input_shape: Tuple[int, ...]):
"""Picks the Fourier Transform implementation."""
self.fourier_transform = _pick_fourier_transform(
self.use_fft,
max_seq_length=input_shape[-2],
hidden_dim=input_shape[-1])
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
"""Applies layer to `query`.
Args:
query: Batch of input embeddings, typically of shape <float>[batch_size,
max_seq_length, hidden_dim].
value: Unused. Included to match attention layer API.
**kwargs: Optional arguments to catch unused attention keyword arguments.
Returns:
Real part of discrete Fourier Transform of `query` inputs with shape
<float32>[batch_size, max_seq_length, hidden_dim].
"""
del value # Ignored by encoder-only mixing layers
query = tf.cast(query, tf.complex64)
return tf.math.real(self.fourier_transform(query))
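
# Example (sketch; `x` is any real [batch, seq, hidden] tensor): with the FFT
# path, the layer's output is simply the real part of the 2D FFT of its
# inputs:
#
#   layer = FourierTransformLayer(use_fft=True)
#   mixed = layer(query=x, value=x)
#   # mixed == tf.math.real(tf.signal.fft2d(tf.cast(x, tf.complex64)))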
class HartleyTransformLayer(MixingLayer):
"""Hartley Transform layer.
Applies 2D Hartley Transform over final two dimensions of `query` inputs -
typically the sequence and hidden dimensions.
"""
def __init__(self,
use_fft: bool = False,
name: str = "hartley_transform",
**kwargs):
"""Initializes layer.
Args:
use_fft: Whether to use Fast Fourier Transform (True) or the Discrete
Fourier Transform (DFT) matrix (False) to compute the Hartley Transform.
See _pick_fourier_transform() for recommendations on when to use FFT or
DFT.
name: Name for layer.
**kwargs: Keyword arguments.
"""
super().__init__(name=name, **kwargs)
self.use_fft = use_fft
def build(self, input_shape: Tuple[int, ...]):
"""Picks the Fourier Transform implementation."""
self.fourier_transform = _pick_fourier_transform(
self.use_fft,
max_seq_length=input_shape[-2],
hidden_dim=input_shape[-1])
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
"""Applies layer to `query`.
Args:
query: Batch of input embeddings, typically of shape <float>[batch_size,
max_seq_length, hidden_dim].
value: Unused. Included to match attention layer API.
**kwargs: Optional arguments to catch unused attention keyword arguments.
Returns:
Real part of discrete Hartley Transform of `query` inputs with shape
<float32>[batch_size, max_seq_length, hidden_dim].
"""
del value # Ignored by encoder-only mixing layers
query = tf.cast(query, tf.complex64)
frequencies = self.fourier_transform(query)
return tf.math.real(frequencies) - tf.math.imag(frequencies)
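
# Note: the Hartley transform is obtained from the Fourier transform via the
# identity H(x) = Re(F(x)) - Im(F(x)), as in `call` above. Its output is
# real-valued by construction, so unlike FourierTransformLayer no imaginary
# component is discarded. Sketch (`x` as above):
#
#   layer = HartleyTransformLayer(use_fft=True)
#   mixed = layer(query=x, value=x)
#   # f = tf.signal.fft2d(tf.cast(x, tf.complex64))
#   # mixed == tf.math.real(f) - tf.math.imag(f)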
class LinearTransformLayer(MixingLayer):
"""Dense, linear transformation layer.
Applies matrix multiplications over sequence and hidden dimensions.
"""
def __init__(self,
kernel_initializer: _Initializer = default_kernel_initializer,
name: str = "linear_transform",
**kwargs):
"""Initializes layer.
Args:
kernel_initializer: Initialization scheme for kernel.
name: Name for layer.
**kwargs: Keyword arguments.
"""
super().__init__(name=name, **kwargs)
self.kernel_initializer = kernel_initializer
def build(self, input_shape: Tuple[int, ...]):
"""Creates the hidden and sequence matrix variables of the layer."""
self.mat_hidden = self.add_weight(
shape=(input_shape[-1], input_shape[-1]),
initializer=tf_utils.clone_initializer(self.kernel_initializer),
trainable=True,
name="hidden_kernel")
self.mat_seq = self.add_weight(
shape=(input_shape[-2], input_shape[-2]),
initializer=tf_utils.clone_initializer(self.kernel_initializer),
trainable=True,
name="seq_kernel")
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
"""Applies layer to `query`.
Args:
query: Batch of input embeddings, typically of shape <float>[batch_size,
max_seq_length, hidden_dim].
value: Unused. Included to match attention layer API.
**kwargs: Optional arguments to catch unused attention keyword arguments.
Returns:
Linearly transformed `query` inputs with shape
<float>[batch_size, max_seq_length, hidden_dim].
"""
del value # Ignored by encoder-only mixing layers
return tf.einsum("bij,jk,ni->bnk", query, self.mat_hidden, self.mat_seq)
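
# The einsum in LinearTransformLayer.call mixes both axes in a single
# contraction. An equivalent two-step form, for readability (shapes noted):
#
#   # query: [batch, seq, hidden], mat_hidden: [hidden, hidden],
#   # mat_seq: [seq, seq].
#   mixed = tf.matmul(query, mat_hidden)              # [batch, seq, hidden]
#   mixed = tf.einsum("ni,bik->bnk", mat_seq, mixed)  # [batch, seq, hidden]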
def _pick_fourier_transform(
use_fft: bool, max_seq_length: int,
hidden_dim: int) -> Callable[[tf.Tensor], tf.Tensor]:
"""Returns FFT or DFT Fourier Transform implementation.
On TPUs, we recommend using the Discrete Fourier Transform (DFT) matrix
(use_fft=False), except for very long sequence lengths. On GPUs and CPUs, the
Fast Fourier Transform (use_fft=True) is generally optimal for all sequence
lengths.
Note: When using the FFT it is recommended to use a sequence length that is a
power of 2.
Args:
use_fft: If True, return FFT. Otherwise, return DFT matrix.
max_seq_length: Maximum sequence length of inputs. Only used if
use_fft=False.
hidden_dim: Size of hidden dimension of inputs. Only used if use_fft=False.
Returns:
Fourier Transform.
"""
if use_fft:
return tf.signal.fft2d
else:
dft_mat_seq = linalg.dft(max_seq_length).astype(np.complex64)
dft_mat_hidden = linalg.dft(hidden_dim).astype(np.complex64)
def two_dim_matmul(x: tf.Tensor, matrix_dim_one: tf.Tensor,
matrix_dim_two: tf.Tensor) -> tf.Tensor:
"""Applies 2D matrix multiplication to input tensors of rank >= 2."""
return tf.einsum("...ij,jk,ni->...nk", tf.cast(x, tf.complex64),
matrix_dim_two, matrix_dim_one)
return functools.partial(
two_dim_matmul,
matrix_dim_one=tf.convert_to_tensor(dft_mat_seq),
matrix_dim_two=tf.convert_to_tensor(dft_mat_hidden))
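
# Note: both branches compute the same 2D transform up to numerical error
# (verified by test_pick_fourier_transform in the accompanying tests); the
# flag is purely a performance trade-off between matmul-friendly TPUs and
# FFT-friendly GPUs/CPUs.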
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mixing.py."""
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import mixing
class MixingTest(tf.test.TestCase):
def test_base_mixing_layer(self):
inputs = tf.random.uniform((3, 8, 16),
minval=0,
maxval=10,
dtype=tf.float32)
with self.assertRaisesRegex(NotImplementedError, "Abstract method"):
_ = mixing.MixingLayer()(query=inputs, value=inputs)
def test_fourier_layer(self):
batch_size = 4
max_seq_length = 8
hidden_dim = 16
inputs = tf.random.uniform((batch_size, max_seq_length, hidden_dim),
minval=0,
maxval=10,
dtype=tf.float32)
outputs = mixing.FourierTransformLayer(use_fft=True)(
query=inputs, value=inputs)
self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))
def test_hartley_layer(self):
batch_size = 3
max_seq_length = 16
hidden_dim = 4
inputs = tf.random.uniform((batch_size, max_seq_length, hidden_dim),
minval=0,
maxval=12,
dtype=tf.float32)
outputs = mixing.HartleyTransformLayer(use_fft=True)(
query=inputs, value=inputs)
self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))
def test_linear_mixing_layer(self):
batch_size = 2
max_seq_length = 4
hidden_dim = 3
inputs = tf.ones((batch_size, max_seq_length, hidden_dim), dtype=tf.float32)
outputs = mixing.LinearTransformLayer(
kernel_initializer=tf.keras.initializers.Ones())(
query=inputs, value=inputs)
    # With all-ones inputs and kernels, each output element sums
    # hidden_dim * max_seq_length = 3 * 4 = 12 ones.
expected_outputs = [
[
[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.],
],
[
[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.],
],
]
np.testing.assert_allclose(outputs, expected_outputs, rtol=1e-6, atol=1e-6)
def test_pick_fourier_transform(self):
    # Seed the RNGs to ensure we don't hit an edge case that exceeds the
    # fixed numerical tolerance.
tf.random.set_seed(1)
np.random.seed(1)
batch_size = 3
max_seq_length = 4
hidden_dim = 8
fft = mixing._pick_fourier_transform(
use_fft=True, max_seq_length=max_seq_length, hidden_dim=hidden_dim)
dft_matmul = mixing._pick_fourier_transform(
use_fft=False, max_seq_length=max_seq_length, hidden_dim=hidden_dim)
inputs = tf.random.uniform([batch_size, max_seq_length, hidden_dim])
inputs = tf.cast(inputs, tf.complex64)
np.testing.assert_allclose(
fft(inputs), dft_matmul(inputs), rtol=1e-6, atol=1e-6)
if __name__ == "__main__":
tf.test.main()