Commit c60499b1, authored Oct 13, 2022 by A. Unique TensorFlower

Implement the Vision TransformerScaffold, which is a subclass of the NLP TransformerScaffold.

PiperOrigin-RevId: 480969429
Parent: ad480628
Changes: 2 changed files, with 679 additions and 0 deletions (+679 -0)

  official/projects/vit/modeling/transformer_scaffold.py        +161 -0
  official/projects/vit/modeling/transformer_scaffold_test.py   +518 -0
official/projects/vit/modeling/transformer_scaffold.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based Scaffold TransformerEncoder block for vision models.
This implementation is subclassed from NLP TransformerScaffold to support
customized `attention_layer` and `feedforward_layer`. In addition, this
implementation has a few features to better support vision use cases:
1. `stochastic_depth_drop_rate` to supress model overfitting.
2. `return_attention_scores`, optionally returns the attention output.
3. `ffn_has_residual_connection`, clearly define whether feedforward network has
residual connection or not to avoid ambiguity.
"""
from typing import List, Optional, Tuple, Union

import gin
import tensorflow as tf

from official.nlp import modeling
from official.vision.modeling.layers.nn_layers import StochasticDepth


@tf.keras.utils.register_keras_serializable(package="Vision")
@gin.configurable
class TransformerScaffold(modeling.layers.TransformerScaffold):
  """TransformerScaffold layer for vision applications.

  This layer is a subclass of the NLP TransformerScaffold.

  Attributes:
    stochastic_depth_drop_rate: Drop rate for the residual connections.
    return_attention_scores: Whether to also return the attention scores.
    ffn_has_residual_connection: Whether the feedforward network has internal
      residual connection and layer norm. If False, the residual connection
      and the layer norm op are called inside TransformerScaffold.
  """

  def __init__(self,
               *args,
               stochastic_depth_drop_rate: float = 0.0,
               return_attention_scores: bool = False,
               ffn_has_residual_connection: bool = False,
               **kwargs):
    """Initializes TransformerScaffold."""
    super().__init__(*args, **kwargs)
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._return_attention_scores = return_attention_scores
    self._ffn_has_residual_connection = ffn_has_residual_connection

  def build(self, input_shape: Union[tf.TensorShape, List[int]]):
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)

    super().build(input_shape)

  def get_config(self):
    config = {
        "stochastic_depth_drop_rate": self._stochastic_depth_drop_rate,
        "return_attention_scores": self._return_attention_scores,
        "ffn_has_residual_connection": self._ffn_has_residual_connection
    }
    base_config = super().get_config()
    base_config.update(config)
    return base_config

  def call(
      self,
      inputs: tf.Tensor,
      training: Optional[bool] = None
  ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
    """Transformer self-attention encoder block call."""
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length %d" %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    if key_value is None:
      key_value = input_tensor

    if self._norm_first:
      source_tensor = input_tensor
      input_tensor = self._attention_layer_norm(input_tensor,
                                                training=training)

    attention_layer_output = self._attention_layer(
        query=input_tensor,
        value=key_value,
        attention_mask=attention_mask,
        training=training,
        return_attention_scores=self._return_attention_scores)
    if isinstance(attention_layer_output, tuple):
      # `attention_layer_output` contains two tensors when
      # `return_attention_scores` is True.
      attention_output, attention_scores = attention_layer_output
    else:
      attention_output = attention_layer_output
    attention_output = self._attention_dropout(attention_output,
                                               training=training)

    if self._norm_first:
      source_attention_output = source_tensor + self._stochastic_depth(
          attention_output, training=training)
      attention_output = self._output_layer_norm(source_attention_output,
                                                 training=training)
    else:
      attention_output = self._attention_layer_norm(
          input_tensor + self._stochastic_depth(attention_output,
                                                training=training),
          training=training)

    if self._feedforward_block is None:
      intermediate_output = self._intermediate_dense(attention_output)
      intermediate_output = self._intermediate_activation_layer(
          intermediate_output)
      layer_output = self._output_dense(intermediate_output,
                                        training=training)
      layer_output = self._output_dropout(layer_output, training=training)
    else:
      layer_output = self._feedforward_block(attention_output,
                                             training=training)

    # During mixed precision training, layer norm output is always fp32 for now.
    # Casts fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)

    if self._norm_first:
      if self._ffn_has_residual_connection:
        raise ValueError(
            "In the case of `norm_first`, the residual connection should be "
            "done in the TransformerScaffold call function, not FFN's "
            "call function.")
      output = source_attention_output + self._stochastic_depth(
          layer_output, training=training)
    else:
      if self._ffn_has_residual_connection:
        output = self._stochastic_depth(layer_output, training=training)
      else:
        output = self._output_layer_norm(
            attention_output + self._stochastic_depth(layer_output,
                                                      training=training))

    if self._return_attention_scores:
      return output, attention_scores
    else:
      return output
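
For orientation (not part of the diff), a minimal usage sketch of the layer above. It assumes the TF Model Garden `official` package is on the Python path; the hyperparameters (12 heads, 768-wide tokens, 3072 inner units, 197 tokens) are illustrative ViT-Base-like values, not something this commit prescribes.

import tensorflow as tf

from official.projects.vit.modeling import transformer_scaffold

# One pre-norm encoder block with stochastic depth and attention-score output.
block = transformer_scaffold.TransformerScaffold(
    num_attention_heads=12,
    inner_dim=3072,
    inner_activation='gelu',
    norm_first=True,                  # pre-layer-norm, as commonly used in ViT
    stochastic_depth_drop_rate=0.1,   # drops the residual branch during training
    return_attention_scores=True)

# [batch, tokens, hidden], e.g. 14x14 patches plus a class token.
features = tf.random.normal([2, 197, 768])
outputs, scores = block(features, training=False)
print(outputs.shape)   # (2, 197, 768)
print(scores.shape)    # (2, 12, 197, 197)

With `return_attention_scores=True` the block returns a tuple, as in the last branch of `call` above; with the default `False` it returns a single tensor of the same shape as the input.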
official/projects/vit/modeling/transformer_scaffold_test.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras-based transformer block layer."""
import numpy as np
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.nlp import modeling
from official.projects.vit.modeling import transformer_scaffold


TransformerScaffold = transformer_scaffold.TransformerScaffold


# Test class that wraps a standard attention layer. If this layer is called
# at any point, the list passed to the config object will be filled with a
# boolean 'True'. We register this class as a Keras serializable so we can
# test serialization below.
@tf.keras.utils.register_keras_serializable(package='TestOnlyAttention')
class ValidatedAttentionLayer(modeling.layers.attention.MultiHeadAttention):

  def __init__(self, call_list, **kwargs):
    super(ValidatedAttentionLayer, self).__init__(**kwargs)
    self.list = call_list

  def call(self,
           query,
           value,
           attention_mask=None,
           return_attention_scores=False):
    self.list.append(True)
    return super(ValidatedAttentionLayer, self).call(
        query,
        value,
        attention_mask=attention_mask,
        return_attention_scores=return_attention_scores)

  def get_config(self):
    config = super(ValidatedAttentionLayer, self).get_config()
    config['call_list'] = []
    return config


# Test class implements a simple feedforward layer. If this layer is called
# at any point, the list passed to the config object will be filled with a
# boolean 'True'. We register this class as a Keras serializable so we can
# test serialization below.
@tf.keras.utils.register_keras_serializable(package='TestOnlyFeedforward')
class ValidatedFeedforwardLayer(tf.keras.layers.Layer):

  def __init__(self, call_list, activation, **kwargs):
    super(ValidatedFeedforwardLayer, self).__init__(**kwargs)
    self.list = call_list
    self.activation = activation

  def build(self, input_shape):
    hidden_size = input_shape.as_list()[-1]
    self._feedforward_dense = tf.keras.layers.EinsumDense(
        '...x,xy->...y',
        output_shape=hidden_size,
        bias_axes='y',
        activation=self.activation,
        name='feedforward')

  def call(self, inputs):
    self.list.append(True)
    return self._feedforward_dense(inputs)

  def get_config(self):
    config = super(ValidatedFeedforwardLayer, self).get_config()
    config['call_list'] = []
    config['activation'] = self.activation
    return config


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TransformerLayerTest(keras_parameterized.TestCase):

  def tearDown(self):
    super(TransformerLayerTest, self).tearDown()
    tf.keras.mixed_precision.set_global_policy('float32')

  def test_layer_creation(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu')

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer should be the same as the input.
    self.assertEqual(data_tensor.shape.as_list(),
                     output_tensor.shape.as_list())
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_layer_creation_with_feedforward_cls(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
    }
    feedforward_call_list = []
    feedforward_layer_cfg = {
        'activation': 'relu',
        'call_list': feedforward_call_list,
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        feedforward_cls=ValidatedFeedforwardLayer,
        feedforward_cfg=feedforward_layer_cfg,
        num_attention_heads=10,
        inner_dim=None,
        inner_activation=None)

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer should be the same as the input.
    self.assertEqual(data_tensor.shape.as_list(),
                     output_tensor.shape.as_list())
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0],
                    "The passed layer class wasn't instantiated.")
    self.assertNotEmpty(feedforward_call_list)
    self.assertTrue(feedforward_call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_layer_creation_with_mask(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu')

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    # Create a 2-dimensional input (the first dimension is implicit).
    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])
    # The default output of a transformer layer should be the same as the input.
    self.assertEqual(data_tensor.shape.as_list(),
                     output_tensor.shape.as_list())
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_layer_invocation(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu')

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)

    # Create a model from the test layer.
    model = tf.keras.Model(data_tensor, output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    _ = model.predict(input_data)
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_layer_invocation_with_feedforward_cls(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
    }
    feedforward_call_list = []
    feedforward_layer_cfg = {
        'activation': 'relu',
        'call_list': feedforward_call_list,
    }
    feedforward_layer = ValidatedFeedforwardLayer(**feedforward_layer_cfg)
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        feedforward_cls=feedforward_layer,
        num_attention_heads=10,
        inner_dim=None,
        inner_activation=None)

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    # Create a 2-dimensional input (the first dimension is implicit).
    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])

    # Create a model from the test layer.
    model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
    # which here is (batch, sequence_length, sequence_length)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    _ = model.predict([input_data, mask_data])
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0],
                    "The passed layer class wasn't instantiated.")
    self.assertNotEmpty(feedforward_call_list)
    self.assertTrue(feedforward_call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_layer_invocation_with_mask(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu')

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    # Create a 2-dimensional input (the first dimension is implicit).
    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])

    # Create a model from the test layer.
    model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
    # which here is (batch, sequence_length, sequence_length)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    _ = model.predict([input_data, mask_data])
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_layer_invocation_with_float16_dtype(self):
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu')

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    # Create a 2-dimensional input (the first dimension is implicit).
    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])

    # Create a model from the test layer.
    model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = (10 * np.random.random_sample(
        (batch_size, sequence_length, width)))
    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
    # which here is (batch, sequence_length, sequence_length)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    _ = model.predict([input_data, mask_data])
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_transform_with_initializer(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu',
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output = test_layer(data_tensor)
    # The default output of a transformer layer should be the same as the input.
    self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0])

  def test_layer_restoration_from_config(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
        'name': 'test_layer',
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu')

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    # Create a 2-dimensional input (the first dimension is implicit).
    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])

    # Create a model from the test layer.
    model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
    # which here is (batch, sequence_length, sequence_length)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    pre_serialization_output = model.predict([input_data, mask_data])

    # Serialize the model config. Pass the serialized data through json to
    # ensure that we can serialize this layer to disk.
    serialized_data = model.get_config()

    # Create a new model from the old config, and copy the weights. These models
    # should have identical outputs.
    new_model = tf.keras.Model.from_config(serialized_data)
    new_model.set_weights(model.get_weights())
    output = new_model.predict([input_data, mask_data])

    self.assertAllClose(pre_serialization_output, output)

    # If the layer was configured correctly, it should have a list attribute
    # (since it should have the custom class and config passed to it).
    new_model.summary()
    new_call_list = new_model.get_layer(
        name='transformer_scaffold')._attention_layer.list
    self.assertNotEmpty(new_call_list)
    self.assertTrue(new_call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_layer_with_feedforward_cls_restoration_from_config(self):
    sequence_length = 21
    width = 80

    call_list = []
    attention_layer_cfg = {
        'num_heads': 10,
        'key_dim': 8,
        'call_list': call_list,
        'name': 'test_layer',
    }
    feedforward_call_list = []
    feedforward_layer_cfg = {
        'activation': 'relu',
        'call_list': feedforward_call_list,
    }
    test_layer = TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        feedforward_cls=ValidatedFeedforwardLayer,
        feedforward_cfg=feedforward_layer_cfg,
        num_attention_heads=10,
        inner_dim=None,
        inner_activation=None)

    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    # Create a 2-dimensional input (the first dimension is implicit).
    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])

    # Create a model from the test layer.
    model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
    # which here is (batch, sequence_length, sequence_length)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    pre_serialization_output = model.predict([input_data, mask_data])

    serialized_data = model.get_config()
    # Create a new model from the old config, and copy the weights. These models
    # should have identical outputs.
    new_model = tf.keras.Model.from_config(serialized_data)
    new_model.set_weights(model.get_weights())
    output = new_model.predict([input_data, mask_data])

    self.assertAllClose(pre_serialization_output, output)

    # If the layer was configured correctly, it should have a list attribute
    # (since it should have the custom class and config passed to it).
    new_model.summary()
    new_call_list = new_model.get_layer(
        name='transformer_scaffold')._attention_layer.list
    self.assertNotEmpty(new_call_list)
    self.assertTrue(new_call_list[0],
                    "The passed layer class wasn't instantiated.")
    new_feedforward_call_list = new_model.get_layer(
        name='transformer_scaffold')._feedforward_block.list
    self.assertNotEmpty(new_feedforward_call_list)
    self.assertTrue(new_feedforward_call_list[0],
                    "The passed layer class wasn't instantiated.")


if __name__ == '__main__':
  tf.test.main()
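
The tests above exercise the custom `attention_cls`/`feedforward_cls` plumbing inherited from the NLP scaffold, but not the new `ffn_has_residual_connection` flag. As a rough sketch of that contract (again not part of the diff), a custom feedforward block that applies its own residual add and layer norm can be plugged in as follows. `ResidualMlp` is a hypothetical illustration, the keyword arguments follow the `feedforward_cls`/`feedforward_cfg` pattern used in the tests, and the sizes are illustrative.

import tensorflow as tf

from official.projects.vit.modeling import transformer_scaffold


class ResidualMlp(tf.keras.layers.Layer):
  """Hypothetical MLP block that applies its own residual add and layer norm."""

  def __init__(self, inner_dim, **kwargs):
    super().__init__(**kwargs)
    self._inner_dim = inner_dim

  def build(self, input_shape):
    hidden_size = input_shape[-1]
    self._expand = tf.keras.layers.Dense(self._inner_dim, activation='gelu')
    self._project = tf.keras.layers.Dense(hidden_size)
    self._norm = tf.keras.layers.LayerNormalization()
    super().build(input_shape)

  def call(self, inputs):
    # The residual connection and norm live inside this block, so the scaffold
    # is told not to add its own via ffn_has_residual_connection=True.
    return self._norm(inputs + self._project(self._expand(inputs)))


block = transformer_scaffold.TransformerScaffold(
    num_attention_heads=12,
    inner_dim=None,
    inner_activation=None,
    feedforward_cls=ResidualMlp,
    feedforward_cfg={'inner_dim': 3072, 'name': 'residual_mlp'},
    ffn_has_residual_connection=True)

outputs = block(tf.random.normal([2, 197, 768]), training=False)
print(outputs.shape)   # (2, 197, 768)

In the post-norm path of `call`, setting `ffn_has_residual_connection=True` makes the scaffold return the feedforward output directly (after optional stochastic depth) instead of adding `attention_output` and applying `_output_layer_norm` a second time.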