Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c0525d49
Commit
c0525d49
authored
Sep 23, 2022
by
James Lee-Thorp
Committed by
A. Unique TensorFlower
Sep 23, 2022
Browse files
Internal change
PiperOrigin-RevId: 476469203
parent
91a0e443
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
480 additions
and
0 deletions
+480
-0
official/nlp/modeling/layers/__init__.py
official/nlp/modeling/layers/__init__.py
+4
-0
official/nlp/modeling/layers/mixing.py
official/nlp/modeling/layers/mixing.py
+14
-0
official/nlp/modeling/networks/README.md
official/nlp/modeling/networks/README.md
+5
-0
official/nlp/modeling/networks/__init__.py
official/nlp/modeling/networks/__init__.py
+1
-0
official/nlp/modeling/networks/fnet.py
official/nlp/modeling/networks/fnet.py
+339
-0
official/nlp/modeling/networks/fnet_test.py
official/nlp/modeling/networks/fnet_test.py
+117
-0
No files found.
official/nlp/modeling/layers/__init__.py
View file @
c0525d49
...
@@ -30,6 +30,10 @@ from official.nlp.modeling.layers.kernel_attention import KernelMask
...
@@ -30,6 +30,10 @@ from official.nlp.modeling.layers.kernel_attention import KernelMask
from
official.nlp.modeling.layers.masked_lm
import
MaskedLM
from
official.nlp.modeling.layers.masked_lm
import
MaskedLM
from
official.nlp.modeling.layers.masked_softmax
import
MaskedSoftmax
from
official.nlp.modeling.layers.masked_softmax
import
MaskedSoftmax
from
official.nlp.modeling.layers.mat_mul_with_margin
import
MatMulWithMargin
from
official.nlp.modeling.layers.mat_mul_with_margin
import
MatMulWithMargin
from
official.nlp.modeling.layers.mixing
import
FourierTransformLayer
from
official.nlp.modeling.layers.mixing
import
HartleyTransformLayer
from
official.nlp.modeling.layers.mixing
import
LinearTransformLayer
from
official.nlp.modeling.layers.mixing
import
MixingMechanism
from
official.nlp.modeling.layers.mobile_bert_layers
import
MobileBertEmbedding
from
official.nlp.modeling.layers.mobile_bert_layers
import
MobileBertEmbedding
from
official.nlp.modeling.layers.mobile_bert_layers
import
MobileBertMaskedLM
from
official.nlp.modeling.layers.mobile_bert_layers
import
MobileBertMaskedLM
from
official.nlp.modeling.layers.mobile_bert_layers
import
MobileBertTransformer
from
official.nlp.modeling.layers.mobile_bert_layers
import
MobileBertTransformer
...
...
official/nlp/modeling/layers/mixing.py
View file @
c0525d49
...
@@ -26,6 +26,7 @@ Note: These mixing layers currently only support encoder stacks. Decoder stacks
...
@@ -26,6 +26,7 @@ Note: These mixing layers currently only support encoder stacks. Decoder stacks
can be supported in the future by utilizing the `value` inputs.
can be supported in the future by utilizing the `value` inputs.
"""
"""
import
enum
import
functools
import
functools
from
typing
import
Callable
,
Tuple
,
Union
from
typing
import
Callable
,
Tuple
,
Union
...
@@ -40,6 +41,19 @@ _Initializer = Union[str, tf.keras.initializers.Initializer]
...
@@ -40,6 +41,19 @@ _Initializer = Union[str, tf.keras.initializers.Initializer]
default_kernel_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
2e-2
)
default_kernel_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
2e-2
)
class
MixingMechanism
(
enum
.
Enum
):
"""Determines the type of mixing layer.
Possible options:
FOURIER: Fourier Transform mixing.
LINEAR: Mixing using dense matrix multiplications with learnable weights.
HARTLEY: Hartley Transform mixing.
"""
FOURIER
=
"fourier"
HARTLEY
=
"hartley"
LINEAR
=
"linear"
class
MixingLayer
(
tf
.
keras
.
layers
.
Layer
):
class
MixingLayer
(
tf
.
keras
.
layers
.
Layer
):
"""Mixing layer base class.
"""Mixing layer base class.
...
...
official/nlp/modeling/networks/README.md
View file @
c0525d49
...
@@ -37,3 +37,8 @@ Generalized Autoregressive Pretraining for Language Understanding"
...
@@ -37,3 +37,8 @@ Generalized Autoregressive Pretraining for Language Understanding"
(https://arxiv.org/abs/1906.08237). It includes embedding lookups,
(https://arxiv.org/abs/1906.08237). It includes embedding lookups,
relative position encodings, mask computations, segment matrix computations and
relative position encodings, mask computations, segment matrix computations and
Transformer XL layers using one or two stream relative self-attention.
Transformer XL layers using one or two stream relative self-attention.
*
[
`FNet`
](
fnet.py
)
implements the encoder model from
[
"FNet: Mixing Tokens with
Fourier Transforms"
](
https://aclanthology.org/2022.naacl-main.319/
)
. FNet has
the same structure as a Transformer encoder, except that all or most of the
self-attention sublayers are replaced with Fourier sublayers.
official/nlp/modeling/networks/__init__.py
View file @
c0525d49
...
@@ -23,6 +23,7 @@ from official.nlp.modeling.networks.bert_encoder import BertEncoder
...
@@ -23,6 +23,7 @@ from official.nlp.modeling.networks.bert_encoder import BertEncoder
from
official.nlp.modeling.networks.bert_encoder
import
BertEncoderV2
from
official.nlp.modeling.networks.bert_encoder
import
BertEncoderV2
from
official.nlp.modeling.networks.classification
import
Classification
from
official.nlp.modeling.networks.classification
import
Classification
from
official.nlp.modeling.networks.encoder_scaffold
import
EncoderScaffold
from
official.nlp.modeling.networks.encoder_scaffold
import
EncoderScaffold
from
official.nlp.modeling.networks.fnet
import
FNet
from
official.nlp.modeling.networks.funnel_transformer
import
FunnelTransformerEncoder
from
official.nlp.modeling.networks.funnel_transformer
import
FunnelTransformerEncoder
from
official.nlp.modeling.networks.mobile_bert_encoder
import
MobileBERTEncoder
from
official.nlp.modeling.networks.mobile_bert_encoder
import
MobileBERTEncoder
from
official.nlp.modeling.networks.packed_sequence_embedding
import
PackedSequenceEmbedding
from
official.nlp.modeling.networks.packed_sequence_embedding
import
PackedSequenceEmbedding
...
...
official/nlp/modeling/networks/fnet.py
0 → 100644
View file @
c0525d49
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FNet encoder network.
Based on ["FNet: Mixing Tokens with Fourier Transforms"]
(https://aclanthology.org/2022.naacl-main.319/).
"""
# pylint: disable=g-classes-have-attributes
from
typing
import
Any
,
Callable
,
Optional
,
Sequence
,
Union
from
absl
import
logging
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.nlp.modeling
import
layers
_Activation
=
Union
[
str
,
Callable
[...,
Any
]]
_Initializer
=
Union
[
str
,
tf
.
keras
.
initializers
.
Initializer
]
_approx_gelu
=
lambda
x
:
tf
.
keras
.
activations
.
gelu
(
x
,
approximate
=
True
)
class
FNet
(
tf
.
keras
.
layers
.
Layer
):
"""FNet encoder network.
Based on ["FNet: Mixing Tokens with Fourier Transforms"]
(https://aclanthology.org/2022.naacl-main.319/). FNet is an efficient
Transformer-like encoder network that replaces self-attention sublayers with
Fourier sublayers.
This implementation defaults to the canonical FNet Base model, but the network
also supports more general mixing models (e.g. 'Linear', 'HNet') and hybrid
models (e.g. 'FNet-Hybrid') models that use both mixing and self-attention
layers.
Args:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
mixing_mechanism: Type of mixing mechanism used in place of self-attention
layers. Defaults to FNet ('Fourier') mixing.
use_fft: Only used for spectral mixing mechanims. Determines whether to use
Fast Fourier Transform (True) or the Discrete Fourier Transform (DFT)
matrix (False; default) to compute the Fourier Transform. See
layers.FourierTransformLayer or layers.HartleyTransformLayer for advice.
attention_layers: Specifies which layers, if any, should be attention layers
in the encoder. The remaining [0, num_layers) setminus attention_layers
will use the specified `mixing_mechanism`. If using attention layers, a
good rule of thumb is to place them in the final few layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initializer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
with_dense_inputs: Whether to accept dense embeddings as the input.
"""
def
__init__
(
self
,
vocab_size
:
int
,
hidden_size
:
int
=
768
,
num_layers
:
int
=
12
,
mixing_mechanism
:
layers
.
MixingMechanism
=
layers
.
MixingMechanism
.
FOURIER
,
use_fft
:
bool
=
False
,
attention_layers
:
Sequence
[
int
]
=
(),
num_attention_heads
:
int
=
12
,
max_sequence_length
:
int
=
512
,
type_vocab_size
:
int
=
16
,
inner_dim
:
int
=
3072
,
inner_activation
:
_Activation
=
_approx_gelu
,
output_dropout
:
float
=
0.1
,
attention_dropout
:
float
=
0.1
,
initializer
:
_Initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
),
output_range
:
Optional
[
int
]
=
None
,
embedding_width
:
Optional
[
int
]
=
None
,
embedding_layer
:
Optional
[
tf
.
keras
.
layers
.
Layer
]
=
None
,
norm_first
:
bool
=
False
,
with_dense_inputs
:
bool
=
False
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
activation
=
tf
.
keras
.
activations
.
get
(
inner_activation
)
initializer
=
tf
.
keras
.
initializers
.
get
(
initializer
)
if
embedding_width
is
None
:
embedding_width
=
hidden_size
self
.
_config
=
{
'vocab_size'
:
vocab_size
,
'hidden_size'
:
hidden_size
,
'num_layers'
:
num_layers
,
'mixing_mechanism'
:
mixing_mechanism
,
'use_fft'
:
use_fft
,
'attention_layers'
:
attention_layers
,
'num_attention_heads'
:
num_attention_heads
,
'max_sequence_length'
:
max_sequence_length
,
'type_vocab_size'
:
type_vocab_size
,
'inner_dim'
:
inner_dim
,
'inner_activation'
:
tf
.
keras
.
activations
.
serialize
(
activation
),
'output_dropout'
:
output_dropout
,
'attention_dropout'
:
attention_dropout
,
'initializer'
:
tf
.
keras
.
initializers
.
serialize
(
initializer
),
'output_range'
:
output_range
,
'embedding_width'
:
embedding_width
,
'embedding_layer'
:
embedding_layer
,
'norm_first'
:
norm_first
,
'with_dense_inputs'
:
with_dense_inputs
,
}
if
embedding_layer
is
None
:
self
.
_embedding_layer
=
layers
.
OnDeviceEmbedding
(
vocab_size
=
vocab_size
,
embedding_width
=
embedding_width
,
initializer
=
tf_utils
.
clone_initializer
(
initializer
),
name
=
'word_embeddings'
)
else
:
self
.
_embedding_layer
=
embedding_layer
self
.
_position_embedding_layer
=
layers
.
PositionEmbedding
(
initializer
=
tf_utils
.
clone_initializer
(
initializer
),
max_length
=
max_sequence_length
,
name
=
'position_embedding'
)
self
.
_type_embedding_layer
=
layers
.
OnDeviceEmbedding
(
vocab_size
=
type_vocab_size
,
embedding_width
=
embedding_width
,
initializer
=
tf_utils
.
clone_initializer
(
initializer
),
use_one_hot
=
True
,
name
=
'type_embeddings'
)
self
.
_embedding_norm_layer
=
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
'embeddings/layer_norm'
,
axis
=-
1
,
epsilon
=
1e-12
,
dtype
=
tf
.
float32
)
self
.
_embedding_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
output_dropout
,
name
=
'embedding_dropout'
)
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self
.
_embedding_projection
=
None
if
embedding_width
!=
hidden_size
:
self
.
_embedding_projection
=
tf
.
keras
.
layers
.
EinsumDense
(
'...x,xy->...y'
,
output_shape
=
hidden_size
,
bias_axes
=
'y'
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
initializer
),
name
=
'embedding_projection'
)
self
.
_transformer_layers
=
[]
for
layer
in
range
(
num_layers
):
if
layer
in
attention_layers
:
mixing_layer
=
layers
.
MultiHeadAttention
(
num_heads
=
num_attention_heads
,
key_dim
=
int
(
hidden_size
//
num_attention_heads
),
dropout
=
attention_dropout
,
use_bias
=
True
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
initializer
),
name
=
'self_attention'
,
)
else
:
mixing_layer
=
self
.
_init_mixing_sublayer
(
layer
)
block
=
layers
.
TransformerScaffold
(
num_attention_heads
=
num_attention_heads
,
inner_dim
=
inner_dim
,
inner_activation
=
inner_activation
,
attention_cls
=
mixing_layer
,
feedforward_cls
=
None
,
# Fallback to default FeedForward class
output_dropout
=
output_dropout
,
attention_dropout
=
attention_dropout
,
norm_first
=
norm_first
,
output_range
=
output_range
if
layer
==
num_layers
-
1
else
None
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
initializer
),
name
=
'transformer/layer_%d'
%
layer
)
self
.
_transformer_layers
.
append
(
block
)
self
.
_attention_mask_layer
=
layers
.
SelfAttentionMask
(
name
=
'self_attention_mask'
)
self
.
_pooler_layer
=
tf
.
keras
.
layers
.
Dense
(
units
=
hidden_size
,
activation
=
'tanh'
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
initializer
),
name
=
'pooler_transform'
)
if
with_dense_inputs
:
self
.
inputs
=
dict
(
input_word_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_mask
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_type_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
dense_inputs
=
tf
.
keras
.
Input
(
shape
=
(
None
,
embedding_width
),
dtype
=
tf
.
float32
),
dense_mask
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
dense_type_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
)
else
:
self
.
inputs
=
dict
(
input_word_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_mask
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_type_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
))
def
call
(
self
,
inputs
):
word_embeddings
=
None
if
isinstance
(
inputs
,
dict
):
word_ids
=
inputs
.
get
(
'input_word_ids'
)
mask
=
inputs
.
get
(
'input_mask'
)
type_ids
=
inputs
.
get
(
'input_type_ids'
)
word_embeddings
=
inputs
.
get
(
'input_word_embeddings'
,
None
)
dense_inputs
=
inputs
.
get
(
'dense_inputs'
,
None
)
dense_mask
=
inputs
.
get
(
'dense_mask'
,
None
)
dense_type_ids
=
inputs
.
get
(
'dense_type_ids'
,
None
)
else
:
raise
ValueError
(
'Unexpected inputs type (%s) to %s.'
%
(
type
(
inputs
),
self
.
__class__
))
if
word_embeddings
is
None
:
word_embeddings
=
self
.
_embedding_layer
(
word_ids
)
if
dense_inputs
is
not
None
:
# Concat the dense embeddings at sequence end.
word_embeddings
=
tf
.
concat
([
word_embeddings
,
dense_inputs
],
axis
=
1
)
type_ids
=
tf
.
concat
([
type_ids
,
dense_type_ids
],
axis
=
1
)
mask
=
tf
.
concat
([
mask
,
dense_mask
],
axis
=
1
)
# Absolute position embeddings.
position_embeddings
=
self
.
_position_embedding_layer
(
word_embeddings
)
type_embeddings
=
self
.
_type_embedding_layer
(
type_ids
)
embeddings
=
word_embeddings
+
position_embeddings
+
type_embeddings
embeddings
=
self
.
_embedding_norm_layer
(
embeddings
)
embeddings
=
self
.
_embedding_dropout
(
embeddings
)
if
self
.
_embedding_projection
is
not
None
:
embeddings
=
self
.
_embedding_projection
(
embeddings
)
attention_mask
=
self
.
_attention_mask_layer
(
embeddings
,
mask
)
encoder_outputs
=
[]
x
=
embeddings
for
layer
in
self
.
_transformer_layers
:
x
=
layer
([
x
,
attention_mask
])
encoder_outputs
.
append
(
x
)
last_encoder_output
=
encoder_outputs
[
-
1
]
first_token_tensor
=
last_encoder_output
[:,
0
,
:]
pooled_output
=
self
.
_pooler_layer
(
first_token_tensor
)
output
=
dict
(
sequence_output
=
encoder_outputs
[
-
1
],
pooled_output
=
pooled_output
,
encoder_outputs
=
encoder_outputs
)
return
output
def
get_embedding_table
(
self
):
return
self
.
_embedding_layer
.
embeddings
def
get_embedding_layer
(
self
):
return
self
.
_embedding_layer
def
get_config
(
self
):
return
dict
(
self
.
_config
)
@
property
def
transformer_layers
(
self
):
"""List of Transformer layers in the encoder."""
return
self
.
_transformer_layers
@
property
def
pooler_layer
(
self
):
"""The pooler dense layer after the transformer layers."""
return
self
.
_pooler_layer
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
if
'embedding_layer'
in
config
and
config
[
'embedding_layer'
]
is
not
None
:
warn_string
=
(
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you contine to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.'
)
print
(
'WARNING: '
+
warn_string
)
logging
.
warn
(
warn_string
)
return
cls
(
**
config
)
def
_init_mixing_sublayer
(
self
,
layer
:
int
):
"""Initializes config-dependent mixing sublayer."""
if
self
.
_config
[
'mixing_mechanism'
]
==
layers
.
MixingMechanism
.
FOURIER
:
mixing_sublayer
=
layers
.
FourierTransformLayer
(
use_fft
=
self
.
_config
[
'use_fft'
],
name
=
'fourier_transform'
)
elif
self
.
_config
[
'mixing_mechanism'
]
==
layers
.
MixingMechanism
.
HARTLEY
:
mixing_sublayer
=
layers
.
HartleyTransformLayer
(
use_fft
=
self
.
_config
[
'use_fft'
],
name
=
'hartley_transform'
)
elif
self
.
_config
[
'mixing_mechanism'
]
==
layers
.
MixingMechanism
.
LINEAR
:
mixing_sublayer
=
layers
.
LinearTransformLayer
(
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_config
[
'initializer'
]),
name
=
'linear_transform'
)
else
:
raise
ValueError
(
'Unsupported mixing mechanism: %s'
%
self
.
_config
[
'mixing_mechanism'
])
return
mixing_sublayer
official/nlp/modeling/networks/fnet_test.py
0 → 100644
View file @
c0525d49
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for FNet encoder network."""
from
typing
import
Sequence
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.nlp.modeling
import
layers
from
official.nlp.modeling.networks
import
fnet
class
FNetTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
tearDown
(
self
):
super
(
FNetTest
,
self
).
tearDown
()
tf
.
keras
.
mixed_precision
.
set_global_policy
(
"float32"
)
@
parameterized
.
named_parameters
(
(
"fnet"
,
layers
.
MixingMechanism
.
FOURIER
,
()),
(
"fnet_hybrid"
,
layers
.
MixingMechanism
.
FOURIER
,
(
1
,
2
)),
(
"hnet"
,
layers
.
MixingMechanism
.
HARTLEY
,
()),
(
"hnet_hybrid"
,
layers
.
MixingMechanism
.
HARTLEY
,
(
1
,
2
)),
(
"linear"
,
layers
.
MixingMechanism
.
LINEAR
,
()),
(
"linear_hybrid"
,
layers
.
MixingMechanism
.
LINEAR
,
(
0
,)),
(
"bert"
,
layers
.
MixingMechanism
.
FOURIER
,
(
0
,
1
,
2
)),
)
def
test_network
(
self
,
mixing_mechanism
:
layers
.
MixingMechanism
,
attention_layers
:
Sequence
[
int
]):
num_layers
=
3
hidden_size
=
32
sequence_length
=
21
test_network
=
fnet
.
FNet
(
vocab_size
=
100
,
hidden_size
=
hidden_size
,
num_attention_heads
=
2
,
num_layers
=
num_layers
,
mixing_mechanism
=
mixing_mechanism
,
attention_layers
=
attention_layers
)
# Create the inputs (note that the first dimension is implicit).
word_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
mask
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
type_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
dict_outputs
=
test_network
(
dict
(
input_word_ids
=
word_ids
,
input_mask
=
mask
,
input_type_ids
=
type_ids
))
data
=
dict_outputs
[
"sequence_output"
]
pooled
=
dict_outputs
[
"pooled_output"
]
self
.
assertIsInstance
(
test_network
.
transformer_layers
,
list
)
self
.
assertLen
(
test_network
.
transformer_layers
,
3
)
self
.
assertIsInstance
(
test_network
.
pooler_layer
,
tf
.
keras
.
layers
.
Dense
)
expected_data_shape
=
[
None
,
sequence_length
,
hidden_size
]
expected_pooled_shape
=
[
None
,
hidden_size
]
self
.
assertAllEqual
(
expected_data_shape
,
data
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_pooled_shape
,
pooled
.
shape
.
as_list
())
# The default output dtype is float32.
self
.
assertAllEqual
(
tf
.
float32
,
data
.
dtype
)
self
.
assertAllEqual
(
tf
.
float32
,
pooled
.
dtype
)
def
test_embeddings_as_inputs
(
self
):
hidden_size
=
32
sequence_length
=
21
test_network
=
fnet
.
FNet
(
vocab_size
=
100
,
hidden_size
=
hidden_size
,
num_attention_heads
=
2
,
num_layers
=
3
)
# Create the inputs (note that the first dimension is implicit).
word_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
),
dtype
=
tf
.
int32
)
mask
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
type_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
test_network
.
build
(
dict
(
input_word_ids
=
word_ids
,
input_mask
=
mask
,
input_type_ids
=
type_ids
))
embeddings
=
test_network
.
get_embedding_layer
()(
word_ids
)
# Calls with the embeddings.
dict_outputs
=
test_network
(
dict
(
input_word_embeddings
=
embeddings
,
input_mask
=
mask
,
input_type_ids
=
type_ids
))
all_encoder_outputs
=
dict_outputs
[
"encoder_outputs"
]
pooled
=
dict_outputs
[
"pooled_output"
]
expected_data_shape
=
[
None
,
sequence_length
,
hidden_size
]
expected_pooled_shape
=
[
None
,
hidden_size
]
self
.
assertLen
(
all_encoder_outputs
,
3
)
for
data
in
all_encoder_outputs
:
self
.
assertAllEqual
(
expected_data_shape
,
data
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_pooled_shape
,
pooled
.
shape
.
as_list
())
# The default output dtype is float32.
self
.
assertAllEqual
(
tf
.
float32
,
all_encoder_outputs
[
-
1
].
dtype
)
self
.
assertAllEqual
(
tf
.
float32
,
pooled
.
dtype
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment