ModelZoo / ResNet50_tensorflow · Commits · fb826f1f

Commit fb826f1f
Authored Sep 15, 2022 by James Lee-Thorp
Committed Sep 15, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 474606464
Parent: fbfa5038
Showing 2 changed files with 378 additions and 0 deletions:

  official/nlp/modeling/layers/mixing.py       +269  -0
  official/nlp/modeling/layers/mixing_test.py  +109  -0

official/nlp/modeling/layers/mixing.py  (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based mixing layers.
Based on the mixing layers use by FNet
(https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers
(https://arxiv.org/abs/2205.12399).
Mixing layers can be used as drop in replacements for self-attention layers. For
interoperability with attention layers, we use the same `query` and `value` call
signature.
Note: These mixing layers currently only support encoder stacks. Decoder stacks
can be supported in the future by utilizing the `value` inputs.
"""

import functools
from typing import Callable, Tuple, Union

import numpy as np
from scipy import linalg
import tensorflow as tf

from official.modeling import tf_utils

_Initializer = Union[str, tf.keras.initializers.Initializer]

default_kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=2e-2)

class MixingLayer(tf.keras.layers.Layer):
  """Mixing layer base class.

  This class cannot be used directly. It just specifies the API for mixing
  layer subclasses. For interoperability with attention layers, we use the
  same `query` and `value` call signature.

  Based on the mixing layers used by FNet
  (https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers
  (https://arxiv.org/abs/2205.12399).
  """

  def __init__(self, name: str = "mixing", **kwargs):
    """Initializes layer.

    Args:
      name: Name for layer.
      **kwargs: Keyword arguments.
    """
    super().__init__(name=name, **kwargs)

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    """Calls the layer.

    Subclasses should return tensors of shape
    <float>[batch_size, max_seq_length, hidden_dim].

    Args:
      query: Batch of input embeddings, typically of shape <float>[batch_size,
        max_seq_length, hidden_dim].
      value: Unused. Included to match attention layer API.
      **kwargs: Optional arguments to catch unused attention keyword arguments.

    Raises:
      NotImplementedError: This class should not be called directly.
    """
    raise NotImplementedError("Abstract method")
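
# --- Editor's hedged sketch, not part of this commit ---
# The subclass contract in the docstring above, made concrete: a mixing layer
# only needs to implement `call` with the attention-compatible `query`/`value`
# signature and return a [batch_size, max_seq_length, hidden_dim] tensor.
# `_IdentityMixingLayer` is a hypothetical name used purely for illustration.
class _IdentityMixingLayer(MixingLayer):
  """Hypothetical no-op mixer: returns `query` unchanged."""

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    del value  # Unused; kept only for attention-API compatibility.
    return query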

class FourierTransformLayer(MixingLayer):
  """Fourier Transform layer.

  Applies 2D Fourier Transform over final two dimensions of `query` inputs -
  typically the sequence and hidden dimensions.
  """

  def __init__(self,
               use_fft: bool = False,
               name: str = "fourier_transform",
               **kwargs):
    """Initializes layer.

    Args:
      use_fft: Whether to use Fast Fourier Transform (True) or the Discrete
        Fourier Transform (DFT) matrix (False) to compute the Fourier
        Transform. See _pick_fourier_transform() for recommendations on when
        to use FFT or DFT.
      name: Name for layer.
      **kwargs: Keyword arguments.
    """
    super().__init__(name=name, **kwargs)
    self.use_fft = use_fft

  def build(self, input_shape: Tuple[int, ...]):
    """Picks the Fourier Transform implementation."""
    self.fourier_transform = _pick_fourier_transform(
        self.use_fft,
        max_seq_length=input_shape[-2],
        hidden_dim=input_shape[-1])

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    """Applies layer to `query`.

    Args:
      query: Batch of input embeddings, typically of shape <float>[batch_size,
        max_seq_length, hidden_dim].
      value: Unused. Included to match attention layer API.
      **kwargs: Optional arguments to catch unused attention keyword arguments.

    Returns:
      Real part of discrete Fourier Transform of `query` inputs with shape
        <float32>[batch_size, max_seq_length, hidden_dim].
    """
    del value  # Ignored by encoder-only mixing layers.
    query = tf.cast(query, tf.complex64)
    return tf.math.real(self.fourier_transform(query))
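
# --- Editor's hedged sketch, not part of this commit ---
# What the layer above computes, spelled out: with use_fft=True the output is
# the real part of a 2D FFT over the last two axes of `query`. The function
# name and the small shapes are illustrative only.
def _fourier_layer_example():
  """Hypothetical check that the layer matches a manual 2D FFT."""
  x = tf.random.uniform((2, 8, 16), dtype=tf.float32)
  layer_out = FourierTransformLayer(use_fft=True)(query=x, value=x)
  manual_out = tf.math.real(tf.signal.fft2d(tf.cast(x, tf.complex64)))
  np.testing.assert_allclose(layer_out, manual_out, rtol=1e-6, atol=1e-6)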

class HartleyTransformLayer(MixingLayer):
  """Hartley Transform layer.

  Applies 2D Hartley Transform over final two dimensions of `query` inputs -
  typically the sequence and hidden dimensions.
  """

  def __init__(self,
               use_fft: bool = False,
               name: str = "hartley_transform",
               **kwargs):
    """Initializes layer.

    Args:
      use_fft: Whether to use Fast Fourier Transform (True) or the Discrete
        Fourier Transform (DFT) matrix (False) to compute the Hartley
        Transform. See _pick_fourier_transform() for recommendations on when
        to use FFT or DFT.
      name: Name for layer.
      **kwargs: Keyword arguments.
    """
    super().__init__(name=name, **kwargs)
    self.use_fft = use_fft

  def build(self, input_shape: Tuple[int, ...]):
    """Picks the Fourier Transform implementation."""
    self.fourier_transform = _pick_fourier_transform(
        self.use_fft,
        max_seq_length=input_shape[-2],
        hidden_dim=input_shape[-1])

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    """Applies layer to `query`.

    Args:
      query: Batch of input embeddings, typically of shape <float>[batch_size,
        max_seq_length, hidden_dim].
      value: Unused. Included to match attention layer API.
      **kwargs: Optional arguments to catch unused attention keyword arguments.

    Returns:
      Discrete Hartley Transform of `query` inputs with shape
        <float32>[batch_size, max_seq_length, hidden_dim].
    """
    del value  # Ignored by encoder-only mixing layers.
    query = tf.cast(query, tf.complex64)
    frequencies = self.fourier_transform(query)
    return tf.math.real(frequencies) - tf.math.imag(frequencies)
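
# --- Editor's hedged sketch, not part of this commit ---
# The layer above obtains the Hartley transform from the Fourier transform
# via the identity H(x) = Re(F(x)) - Im(F(x)). Unlike keeping only Re(F(x)),
# the Hartley transform of a real signal is itself real and self-inverse up
# to a scale factor, so no information is discarded. Illustrative names only.
def _hartley_layer_example():
  """Hypothetical check of the H(x) = Re(F(x)) - Im(F(x)) identity."""
  x = tf.random.uniform((2, 8, 16), dtype=tf.float32)
  layer_out = HartleyTransformLayer(use_fft=True)(query=x, value=x)
  freqs = tf.signal.fft2d(tf.cast(x, tf.complex64))
  manual_out = tf.math.real(freqs) - tf.math.imag(freqs)
  np.testing.assert_allclose(layer_out, manual_out, rtol=1e-6, atol=1e-6)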

class LinearTransformLayer(MixingLayer):
  """Dense, linear transformation layer.

  Applies matrix multiplications over sequence and hidden dimensions.
  """

  def __init__(self,
               kernel_initializer: _Initializer = default_kernel_initializer,
               name: str = "linear_transform",
               **kwargs):
    """Initializes layer.

    Args:
      kernel_initializer: Initialization scheme for kernel.
      name: Name for layer.
      **kwargs: Keyword arguments.
    """
    super().__init__(name=name, **kwargs)
    self.kernel_initializer = kernel_initializer

  def build(self, input_shape: Tuple[int, ...]):
    """Creates the hidden and sequence matrix variables of the layer."""
    self.mat_hidden = self.add_weight(
        shape=(input_shape[-1], input_shape[-1]),
        initializer=tf_utils.clone_initializer(self.kernel_initializer),
        trainable=True,
        name="hidden_kernel")
    self.mat_seq = self.add_weight(
        shape=(input_shape[-2], input_shape[-2]),
        initializer=tf_utils.clone_initializer(self.kernel_initializer),
        trainable=True,
        name="seq_kernel")

  def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor:
    """Applies layer to `query`.

    Args:
      query: Batch of input embeddings, typically of shape <float>[batch_size,
        max_seq_length, hidden_dim].
      value: Unused. Included to match attention layer API.
      **kwargs: Optional arguments to catch unused attention keyword arguments.

    Returns:
      Linearly transformed `query` inputs with shape
        <float>[batch_size, max_seq_length, hidden_dim].
    """
    del value  # Ignored by encoder-only mixing layers.
    return tf.einsum("bij,jk,ni->bnk", query, self.mat_hidden, self.mat_seq)
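
# --- Editor's hedged sketch, not part of this commit ---
# The einsum "bij,jk,ni->bnk" above is the batched product
# mat_seq @ query @ mat_hidden: mat_hidden ("jk") mixes across the hidden
# dimension and mat_seq ("ni") mixes across the sequence dimension. The
# function below is a hypothetical decomposition into two plain matmuls.
def _linear_layer_example():
  """Hypothetical check that the einsum equals two successive matmuls."""
  x = tf.random.uniform((2, 4, 3))  # [batch, seq, hidden]
  layer = LinearTransformLayer()
  out = layer(query=x, value=x)  # Builds mat_seq (4x4) and mat_hidden (3x3).
  manual = tf.matmul(tf.matmul(layer.mat_seq[tf.newaxis, ...], x),
                     layer.mat_hidden)
  np.testing.assert_allclose(out, manual, rtol=1e-6, atol=1e-6)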

def _pick_fourier_transform(
    use_fft: bool,
    max_seq_length: int,
    hidden_dim: int) -> Callable[[tf.Tensor], tf.Tensor]:
  """Returns FFT or DFT Fourier Transform implementation.

  On TPUs, we recommend using the Discrete Fourier Transform (DFT) matrix
  (use_fft=False), except for very long sequence lengths. On GPUs and CPUs,
  the Fast Fourier Transform (use_fft=True) is generally optimal for all
  sequence lengths.

  Note: When using the FFT, it is recommended to use a sequence length that
  is a power of 2.

  Args:
    use_fft: If True, return FFT. Otherwise, return DFT matrix.
    max_seq_length: Maximum sequence length of inputs. Only used if
      use_fft=False.
    hidden_dim: Size of hidden dimension of inputs. Only used if
      use_fft=False.

  Returns:
    Fourier Transform.
  """
  if use_fft:
    return tf.signal.fft2d
  else:
    dft_mat_seq = linalg.dft(max_seq_length).astype(np.complex64)
    dft_mat_hidden = linalg.dft(hidden_dim).astype(np.complex64)

    def two_dim_matmul(x: tf.Tensor, matrix_dim_one: tf.Tensor,
                       matrix_dim_two: tf.Tensor) -> tf.Tensor:
      """Applies 2D matrix multiplication to input tensors of rank >= 2."""
      return tf.einsum("...ij,jk,ni->...nk", tf.cast(x, tf.complex64),
                       matrix_dim_two, matrix_dim_one)

    return functools.partial(
        two_dim_matmul,
        matrix_dim_one=tf.convert_to_tensor(dft_mat_seq),
        matrix_dim_two=tf.convert_to_tensor(dft_mat_hidden))
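
The sketch below (an editor's illustration, not part of this commit) shows the intended drop-in usage described in the module docstring: because mixing layers share the `query`/`value` call signature with attention layers, one can be substituted into a standard Transformer encoder block with no other changes. The block structure, layer choices, and sizes here are hypothetical.

import tensorflow as tf
from official.nlp.modeling.layers import mixing

hidden_dim = 16
inputs = tf.keras.Input(shape=(8, hidden_dim))  # [batch, seq_len, hidden_dim]

# Mixing sublayer in place of self-attention, with a residual connection.
# use_fft=True follows the guidance in _pick_fourier_transform for GPU/CPU;
# on TPU, use_fft=False (the DFT matrix) is generally recommended.
mixed = mixing.FourierTransformLayer(use_fft=True)(query=inputs, value=inputs)
x = tf.keras.layers.LayerNormalization()(inputs + mixed)

# Standard position-wise feed-forward sublayer.
ffn = tf.keras.layers.Dense(4 * hidden_dim, activation="gelu")(x)
ffn = tf.keras.layers.Dense(hidden_dim)(ffn)
outputs = tf.keras.layers.LayerNormalization()(x + ffn)

encoder_block = tf.keras.Model(inputs=inputs, outputs=outputs)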

official/nlp/modeling/layers/mixing_test.py  (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mixing.py."""

import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import mixing


class MixingTest(tf.test.TestCase):

  def test_base_mixing_layer(self):
    inputs = tf.random.uniform(
        (3, 8, 16), minval=0, maxval=10, dtype=tf.float32)
    with self.assertRaisesRegex(NotImplementedError, "Abstract method"):
      _ = mixing.MixingLayer()(query=inputs, value=inputs)

  def test_fourier_layer(self):
    batch_size = 4
    max_seq_length = 8
    hidden_dim = 16
    inputs = tf.random.uniform((batch_size, max_seq_length, hidden_dim),
                               minval=0, maxval=10, dtype=tf.float32)
    outputs = mixing.FourierTransformLayer(use_fft=True)(
        query=inputs, value=inputs)
    self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))

  def test_hartley_layer(self):
    batch_size = 3
    max_seq_length = 16
    hidden_dim = 4
    inputs = tf.random.uniform((batch_size, max_seq_length, hidden_dim),
                               minval=0, maxval=12, dtype=tf.float32)
    outputs = mixing.HartleyTransformLayer(use_fft=True)(
        query=inputs, value=inputs)
    self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))

  def test_linear_mixing_layer(self):
    batch_size = 2
    max_seq_length = 4
    hidden_dim = 3
    inputs = tf.ones((batch_size, max_seq_length, hidden_dim),
                     dtype=tf.float32)
    outputs = mixing.LinearTransformLayer(
        kernel_initializer=tf.keras.initializers.Ones())(
            query=inputs, value=inputs)
    # hidden_dim * (max_seq_length * 1) = 12.
    expected_outputs = [
        [[12., 12., 12.], [12., 12., 12.], [12., 12., 12.], [12., 12., 12.]],
        [[12., 12., 12.], [12., 12., 12.], [12., 12., 12.], [12., 12., 12.]],
    ]
    np.testing.assert_allclose(
        outputs, expected_outputs, rtol=1e-6, atol=1e-6)

  def test_pick_fourier_transform(self):
    # Ensure we don't hit an edge case which exceeds the fixed numerical error.
    tf.random.set_seed(1)
    np.random.seed(1)
    batch_size = 3
    max_seq_length = 4
    hidden_dim = 8

    fft = mixing._pick_fourier_transform(
        use_fft=True, max_seq_length=max_seq_length, hidden_dim=hidden_dim)
    dft_matmul = mixing._pick_fourier_transform(
        use_fft=False, max_seq_length=max_seq_length, hidden_dim=hidden_dim)

    inputs = tf.random.uniform([batch_size, max_seq_length, hidden_dim])
    inputs = tf.cast(inputs, tf.complex64)

    np.testing.assert_allclose(
        fft(inputs), dft_matmul(inputs), rtol=1e-6, atol=1e-6)


if __name__ == "__main__":
  tf.test.main()