Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f3a61a49
Commit
f3a61a49
authored
Dec 04, 2019
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 283778598
parent
4270e416
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
597 additions
and
0 deletions
+597
-0
official/nlp/modeling/layers/transformer_scaffold.py
official/nlp/modeling/layers/transformer_scaffold.py
+250
-0
official/nlp/modeling/layers/transformer_scaffold_test.py
official/nlp/modeling/layers/transformer_scaffold_test.py
+347
-0
No files found.
official/nlp/modeling/layers/transformer_scaffold.py
0 → 100644
View file @
f3a61a49
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based transformer scaffold layer."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
attention
from
official.nlp.modeling.layers
import
dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerScaffold(tf.keras.layers.Layer):
  """Transformer scaffold layer.

  This layer implements the Transformer from "Attention Is All You Need".
  (https://arxiv.org/abs/1706.03762), with a customizable attention layer
  option. Users can pass a class to `attention_cls` and associated config to
  `attention_cfg`, in which case the scaffold will instantiate the class with
  the config, or pass a class instance to `attention_cls`.

  Arguments:
    num_attention_heads: Number of attention heads.
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    attention_cls: A class to instantiate, or a layer instance.
    attention_cfg: The config with which to instantiate `attention_cls`.
      Ignored if attention_cls is a layer instance. If None, a default config
      is assembled in `build()` from the other constructor arguments.
    dropout_rate: Dropout probability for the post-attention and output
      dropout.
    attention_dropout_rate: Dropout probability for within the attention
      layer. Only used when the default attention config is assembled.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
    **kwargs: Forwarded unchanged to the base layer constructor.
  """

  def __init__(self,
               num_attention_heads,
               intermediate_size,
               intermediate_activation,
               attention_cls=attention.Attention,
               attention_cfg=None,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(TransformerScaffold, self).__init__(**kwargs)

    self._attention_cfg = attention_cfg
    self._attention_cls = attention_cls
    self._num_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._intermediate_activation = intermediate_activation
    self._attention_dropout_rate = attention_dropout_rate
    self._dropout_rate = dropout_rate
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    # `activity_regularizer` is a named argument here, so it is not part of
    # **kwargs and never reaches the base constructor. It must be stored
    # explicitly (after the base constructor has run) or the value read by
    # build() and get_config() would not be the one the caller supplied.
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

  def build(self, input_shape):
    """Validates the input shapes and creates all sublayers.

    Arguments:
      input_shape: Either the shape of the data tensor, or a two-element
        sequence `[data_shape, mask_shape]`.

    Raises:
      ValueError: If the data shape is not three-dimensional, if the mask
        shape is incompatible with [batch, sequence, sequence], or if the
        input width is not a multiple of the number of attention heads.
    """
    # A length-two `input_shape` is interpreted as [data_shape, mask_shape];
    # anything else is taken to be the data shape itself.
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor)
    if len(input_tensor_shape) != 3:
      raise ValueError(
          "TransformerScaffold expects a three-dimensional input of "
          "shape [batch, sequence, width].")
    batch_size, sequence_length, hidden_size = input_tensor_shape

    if len(input_shape) == 2:
      mask_tensor_shape = tf.TensorShape(input_shape[1])
      expected_mask_tensor_shape = tf.TensorShape(
          [batch_size, sequence_length, sequence_length])
      if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
        raise ValueError(
            "When passing a mask tensor to TransformerScaffold, the "
            "mask tensor must be of shape [batch, "
            "sequence_length, sequence_length] (here %s). Got a "
            "mask tensor of shape %s." %
            (expected_mask_tensor_shape, mask_tensor_shape))

    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)

    # Initializer/regularizer/constraint settings shared by every dense
    # sublayer and by the default attention config.
    dense_kwargs = {
        "kernel_initializer": self._kernel_initializer,
        "bias_initializer": self._bias_initializer,
        "kernel_regularizer": self._kernel_regularizer,
        "bias_regularizer": self._bias_regularizer,
        "activity_regularizer": self._activity_regularizer,
        "kernel_constraint": self._kernel_constraint,
        "bias_constraint": self._bias_constraint,
    }

    if isinstance(self._attention_cls, tf.keras.layers.Layer):
      # An already-constructed layer is used directly; `attention_cfg` is
      # ignored in this case.
      self._attention_layer = self._attention_cls
    else:
      if self._attention_cfg is None:
        attention_cfg = dict(
            dense_kwargs,
            num_heads=self._num_heads,
            head_size=self._attention_head_size,
            dropout_rate=self._attention_dropout_rate,
            name="self_attention")
      else:
        attention_cfg = self._attention_cfg
      self._attention_layer = self._attention_cls(**attention_cfg)

    self._attention_output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size,
        num_summed_dimensions=2,
        name="self_attention_output",
        **dense_kwargs)
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self._attention_layer_norm = tf.keras.layers.LayerNormalization(
        name="self_attention_layer_norm",
        axis=-1,
        epsilon=1e-12,
        dtype=tf.float32)
    self._intermediate_dense = dense_einsum.DenseEinsum(
        output_shape=self._intermediate_size,
        activation=self._intermediate_activation,
        dtype=tf.float32,  # This layer is always float32 for numeric stability.
        name="intermediate",
        **dense_kwargs)
    self._output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size, name="output", **dense_kwargs)
    self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)

    super(TransformerScaffold, self).build(input_shape)

  def compute_output_shape(self, input_shape):
    """Returns the output shape, which equals the data tensor's shape.

    The final dense projection built in `build()` maps back to the input
    width, so the output is [batch, sequence, width] exactly like the data
    input.

    Arguments:
      input_shape: Either the shape of the data tensor, or a two-element
        sequence `[data_shape, mask_shape]` (same convention as `build()`).

    Returns:
      A `tf.TensorShape` of [batch, sequence, width].
    """
    data_shape = input_shape[0] if len(input_shape) == 2 else input_shape
    data_tensor_shape = tf.TensorShape(data_shape)
    batch = data_tensor_shape[0]
    sequence_length = data_tensor_shape[1]
    width = data_tensor_shape[2]
    return tf.TensorShape((batch, sequence_length, width))

  def get_config(self):
    """Returns the layer configuration.

    Note that `attention_cls` is reported as the attention layer instance
    that `build()` created or adopted, so this method requires the layer to
    have been built.
    """
    config = {
        "attention_cls":
            self._attention_layer,
        "num_attention_heads":
            self._num_heads,
        "intermediate_size":
            self._intermediate_size,
        "intermediate_activation":
            self._intermediate_activation,
        "dropout_rate":
            self._dropout_rate,
        "attention_dropout_rate":
            self._attention_dropout_rate,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super(TransformerScaffold, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    """Runs attention, then the feed-forward block, each with a residual.

    Arguments:
      inputs: Either the data tensor, or a two-element list/tuple
        `[data_tensor, attention_mask]`.

    Returns:
      The transformed tensor; cast back to float16 when the layer dtype is
      float16.
    """
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
      input_tensor, attention_mask = inputs
    else:
      input_tensor, attention_mask = (inputs, None)

    # Self-attention: the data tensor is passed twice to the attention layer.
    attention_inputs = [input_tensor, input_tensor]
    if attention_mask is not None:
      attention_inputs.append(attention_mask)

    attention_output = self._attention_layer(attention_inputs)
    attention_output = self._attention_output_dense(attention_output)
    attention_output = self._attention_dropout(attention_output)

    # Use float32 in keras layer norm and the gelu activation in the
    # intermediate dense layer for numeric stability
    if self.dtype == tf.float16:
      input_tensor = tf.cast(input_tensor, tf.float32)
      attention_output = tf.cast(attention_output, tf.float32)
    attention_output = self._attention_layer_norm(input_tensor +
                                                  attention_output)

    intermediate_output = self._intermediate_dense(attention_output)
    if self.dtype == tf.float16:
      intermediate_output = tf.cast(intermediate_output, tf.float16)

    layer_output = self._output_dense(intermediate_output)
    layer_output = self._output_dropout(layer_output)

    # Use float32 in keras layer norm for numeric stability
    if self.dtype == tf.float16:
      layer_output = tf.cast(layer_output, tf.float32)
    layer_output = self._output_layer_norm(layer_output + attention_output)
    if self.dtype == tf.float16:
      layer_output = tf.cast(layer_output, tf.float16)
    return layer_output
official/nlp/modeling/layers/transformer_scaffold_test.py
0 → 100644
View file @
f3a61a49
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based transformer block layer."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
json
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.nlp.modeling.layers
import
attention
from
official.nlp.modeling.layers
import
transformer_scaffold
# Test class that wraps a standard attention layer. If this layer is called
# at any point, the list passed to the config object will be filled with a
# boolean 'True'. We register this class as a Keras serializable so we can
# test serialization below.
@tf.keras.utils.register_keras_serializable(package='TestOnly')
class ValidatedAttentionLayer(attention.Attention):
  """Attention layer that records every invocation in a caller-owned list."""

  def __init__(self, call_list, **kwargs):
    """Stores `call_list` and forwards all other arguments to the parent."""
    super(ValidatedAttentionLayer, self).__init__(**kwargs)
    self.list = call_list

  def call(self, inputs):
    """Appends `True` to the recorded list, then defers to the parent."""
    self.list.append(True)
    parent_call = super(ValidatedAttentionLayer, self).call
    return parent_call(inputs)

  def get_config(self):
    """Returns the parent config plus an empty `call_list` entry."""
    layer_config = super(ValidatedAttentionLayer, self).get_config()
    # The recorded calls are not carried over; a restored layer starts with
    # a fresh, empty list.
    layer_config.update({'call_list': []})
    return layer_config
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TransformerLayerTest(keras_parameterized.TestCase):
  """Tests TransformerScaffold using a call-recording attention layer."""

  # Sizes shared by every test case.
  _SEQ_LEN = 21
  _WIDTH = 80
  _BATCH = 6

  def _build_layer(self, call_list, extra_attention_cfg=None, **layer_kwargs):
    """Creates a scaffold whose attention class is ValidatedAttentionLayer.

    Arguments:
      call_list: List handed to the attention layer through its config.
      extra_attention_cfg: Optional dict of additional attention config
        entries, added after the standard ones.
      **layer_kwargs: Extra keyword arguments for TransformerScaffold.

    Returns:
      The newly constructed TransformerScaffold.
    """
    attention_layer_cfg = {
        'num_heads': 10,
        'head_size': 8,
        'call_list': call_list,
    }
    if extra_attention_cfg:
      attention_layer_cfg.update(extra_attention_cfg)
    return transformer_scaffold.TransformerScaffold(
        attention_cls=ValidatedAttentionLayer,
        attention_cfg=attention_layer_cfg,
        num_attention_heads=10,
        intermediate_size=2048,
        intermediate_activation='relu',
        **layer_kwargs)

  def _assert_attention_called(self, call_list):
    """Asserts that the attention layer built from the config was invoked."""
    # If call_list[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(call_list)
    self.assertTrue(call_list[0],
                    "The passed layer class wasn't instantiated.")

  def test_layer_creation(self):
    recorded = []
    scaffold = self._build_layer(recorded)

    # Symbolic data input; the batch dimension is implicit.
    data_in = tf.keras.Input(shape=(self._SEQ_LEN, self._WIDTH))
    result = scaffold(data_in)

    # By default the layer's output shape matches its input shape.
    self.assertEqual(data_in.shape.as_list(), result.shape.as_list())
    self._assert_attention_called(recorded)

  def test_layer_creation_with_mask(self):
    recorded = []
    scaffold = self._build_layer(recorded)

    # Symbolic data and mask inputs; the batch dimension is implicit.
    data_in = tf.keras.Input(shape=(self._SEQ_LEN, self._WIDTH))
    mask_in = tf.keras.Input(shape=(self._SEQ_LEN, self._SEQ_LEN))
    result = scaffold([data_in, mask_in])

    # By default the layer's output shape matches its input shape.
    self.assertEqual(data_in.shape.as_list(), result.shape.as_list())
    self._assert_attention_called(recorded)

  def test_layer_creation_with_incorrect_mask_fails(self):
    recorded = []
    scaffold = self._build_layer(recorded)

    data_in = tf.keras.Input(shape=(self._SEQ_LEN, self._WIDTH))
    # The mask's last dimension deliberately disagrees with the sequence
    # length.
    mask_in = tf.keras.Input(shape=(self._SEQ_LEN, self._SEQ_LEN - 3))
    with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
      _ = scaffold([data_in, mask_in])

  def test_layer_invocation(self):
    recorded = []
    scaffold = self._build_layer(recorded)

    data_in = tf.keras.Input(shape=(self._SEQ_LEN, self._WIDTH))
    result = scaffold(data_in)

    # Wrap the layer in a model so it can be run on concrete data.
    model = tf.keras.Model(data_in, result)

    # The output values are not validated; running the model only rules out
    # structural runtime errors.
    sample_shape = (self._BATCH, self._SEQ_LEN, self._WIDTH)
    input_data = 10 * np.random.random_sample(sample_shape)
    _ = model.predict(input_data)
    self._assert_attention_called(recorded)

  def test_layer_invocation_with_mask(self):
    recorded = []
    scaffold = self._build_layer(recorded)

    data_in = tf.keras.Input(shape=(self._SEQ_LEN, self._WIDTH))
    mask_in = tf.keras.Input(shape=(self._SEQ_LEN, self._SEQ_LEN))
    result = scaffold([data_in, mask_in])

    # Wrap the layer in a model so it can be run on concrete data.
    model = tf.keras.Model([data_in, mask_in], result)

    # The output values are not validated; running the model only rules out
    # structural runtime errors.
    sample_shape = (self._BATCH, self._SEQ_LEN, self._WIDTH)
    input_data = 10 * np.random.random_sample(sample_shape)
    # The attention mask is (batch, from_seq_len, to_seq_len), which here is
    # (batch, sequence_length, sequence_length).
    mask_shape = (self._BATCH, self._SEQ_LEN, self._SEQ_LEN)
    mask_data = np.random.randint(2, size=mask_shape)
    _ = model.predict([input_data, mask_data])
    self._assert_attention_called(recorded)

  def test_layer_invocation_with_float16_dtype(self):
    recorded = []
    scaffold = self._build_layer(recorded, dtype='float16')

    data_in = tf.keras.Input(
        shape=(self._SEQ_LEN, self._WIDTH), dtype=tf.float16)
    mask_in = tf.keras.Input(shape=(self._SEQ_LEN, self._SEQ_LEN))
    result = scaffold([data_in, mask_in])

    # Wrap the layer in a model so it can be run on concrete data.
    model = tf.keras.Model([data_in, mask_in], result)

    # The output values are not validated; running the model only rules out
    # structural runtime errors.
    sample_shape = (self._BATCH, self._SEQ_LEN, self._WIDTH)
    input_data = (10 * np.random.random_sample(sample_shape)).astype(
        np.float16)
    # The attention mask is (batch, from_seq_len, to_seq_len), which here is
    # (batch, sequence_length, sequence_length).
    mask_shape = (self._BATCH, self._SEQ_LEN, self._SEQ_LEN)
    mask_data = np.random.randint(2, size=mask_shape)
    _ = model.predict([input_data, mask_data])
    self._assert_attention_called(recorded)

  def test_transform_with_initializer(self):
    recorded = []
    scaffold = self._build_layer(
        recorded,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))

    data_in = tf.keras.Input(shape=(self._SEQ_LEN, self._WIDTH))
    result = scaffold(data_in)

    # By default the layer's output shape matches its input shape.
    self.assertEqual(data_in.shape.as_list(), result.shape.as_list())
    # If recorded[0] exists and is True, the passed layer class was
    # instantiated from the given config properly.
    self.assertNotEmpty(recorded)
    self.assertTrue(recorded[0])

  def test_layer_restoration_from_config(self):
    recorded = []
    scaffold = self._build_layer(
        recorded, extra_attention_cfg={'name': 'test_layer'})

    data_in = tf.keras.Input(shape=(self._SEQ_LEN, self._WIDTH))
    mask_in = tf.keras.Input(shape=(self._SEQ_LEN, self._SEQ_LEN))
    result = scaffold([data_in, mask_in])

    # Wrap the layer in a model so it can be run on concrete data.
    model = tf.keras.Model([data_in, mask_in], result)

    sample_shape = (self._BATCH, self._SEQ_LEN, self._WIDTH)
    input_data = 10 * np.random.random_sample(sample_shape)
    # The attention mask is (batch, from_seq_len, to_seq_len), which here is
    # (batch, sequence_length, sequence_length).
    mask_shape = (self._BATCH, self._SEQ_LEN, self._SEQ_LEN)
    mask_data = np.random.randint(2, size=mask_shape)

    pre_serialization_output = model.predict([input_data, mask_data])

    # Round-trip the model config through a JSON string to make sure the
    # layer can be serialized to disk.
    serialized_data = json.dumps(model.get_config())
    post_string_serialized_data = json.loads(serialized_data)

    # Rebuild a model from the round-tripped config and copy the weights
    # over; both models should then produce identical outputs.
    new_model = tf.keras.Model.from_config(post_string_serialized_data)
    new_model.set_weights(model.get_weights())
    output = new_model.predict([input_data, mask_data])

    self.assertAllClose(pre_serialization_output, output)

    # If the layer was configured correctly, its attention layer has a list
    # attribute (it is the custom class, built from the passed config).
    new_model.summary()
    restored_layer = new_model.get_layer(name='transformer_scaffold')
    new_call_list = restored_layer._attention_layer.list
    self.assertNotEmpty(new_call_list)
    self.assertTrue(new_call_list[0],
                    "The passed layer class wasn't instantiated.")
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment