Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7b2df02f
Commit
7b2df02f
authored
Mar 10, 2022
by
Jinoo Baek
Committed by
A. Unique TensorFlower
Mar 10, 2022
Browse files
Internal change
PiperOrigin-RevId: 433892706
parent
4f93d3b1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
399 additions
and
9 deletions
+399
-9
official/nlp/modeling/layers/transformer_encoder_block.py
official/nlp/modeling/layers/transformer_encoder_block.py
+85
-8
official/nlp/modeling/layers/transformer_encoder_block_test.py
...ial/nlp/modeling/layers/transformer_encoder_block_test.py
+314
-1
No files found.
official/nlp/modeling/layers/transformer_encoder_block.py
View file @
7b2df02f
...
...
@@ -54,9 +54,31 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
inner_dropout
=
0.0
,
attention_initializer
=
None
,
attention_axes
=
None
,
use_query_residual
=
True
,
key_dim
=
None
,
value_dim
=
None
,
output_last_dim
=
None
,
diff_q_kv_att_layer_norm
=
False
,
**
kwargs
):
"""Initializes `TransformerEncoderBlock`.
Note: If `output_last_dim` is used and `use_query_residual` is `True`, the
`output_last_dim`'s value must equal the first input's last dimension for
the query residual connection to work. This is because the residual
connection after the multi-head-attention requires their dimensions to
match. If `use_query_residual` is `False`, the `output_last_dim` dictactes
the last dimension of the output of this module and the
multi-head-attention.
E.g. let's say input dims are `[batch_size, seq_dim, input_last_dim]`.
Scenario 1: If `output_last_dim` is not `None`, then the output dims of this
module would be `[batch_size, seq_dim, output_last_dim]`. Note `key_dim` is
is overriden by `output_last_dim`.
Scenario 2: If `output_last_dim` is `None` and `key_dim` is not `None`, then
the output dims of this module would be `[batch_size, seq_dim, key_dim]`.
Scenario 3: If the `output_last_dim` and `key_dim` are both `None`, the
output dims would be `[batch_size, seq_dim, input_last_dim]`.
Args:
num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer
...
...
@@ -88,6 +110,18 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
use_query_residual: Toggle to execute residual connection after attention.
key_dim: `key_dim` for the `tf.keras.layers.MultiHeadAttention`. If
`None`, we use the first `input_shape`'s last dim.
value_dim: `value_dim` for the `tf.keras.layers.MultiHeadAttention`.
output_last_dim: Final dimension of the output of this module. This also
dictates the value for the final dimension of the
multi-head-attention. When it's `None`, we use, in order of decreasing
precedence, `key_dim` * `num_heads` or the first `input_shape`'s last
dim as the output's last dim.
diff_q_kv_att_layer_norm: If `True`, create a separate attention layer
norm layer for query and key-value if `norm_first` is `True`. Invalid
to set to `True` if `norm_first` is `False`.
**kwargs: keyword arguments.
"""
util
.
filter_kwargs
(
kwargs
)
...
...
@@ -112,6 +146,11 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self
.
_norm_first
=
norm_first
self
.
_norm_epsilon
=
norm_epsilon
self
.
_inner_dropout
=
inner_dropout
self
.
_use_query_residual
=
use_query_residual
self
.
_key_dim
=
key_dim
self
.
_value_dim
=
value_dim
self
.
_output_last_dim
=
output_last_dim
self
.
_diff_q_kv_att_layer_norm
=
diff_q_kv_att_layer_norm
if
attention_initializer
:
self
.
_attention_initializer
=
tf
.
keras
.
initializers
.
get
(
attention_initializer
)
...
...
@@ -119,6 +158,10 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self
.
_attention_initializer
=
self
.
_kernel_initializer
self
.
_attention_axes
=
attention_axes
if
self
.
_diff_q_kv_att_layer_norm
and
not
self
.
_norm_first
:
raise
ValueError
(
"Setting `diff_q_and_kv_attention_layer_norm` to True"
"when `norm_first` is False is invalid."
)
def
build
(
self
,
input_shape
):
if
isinstance
(
input_shape
,
tf
.
TensorShape
):
input_tensor_shape
=
input_shape
...
...
@@ -136,7 +179,13 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
raise
ValueError
(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)"
%
(
hidden_size
,
self
.
_num_heads
))
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
if
self
.
_key_dim
is
None
:
self
.
_key_dim
=
int
(
hidden_size
//
self
.
_num_heads
)
if
self
.
_output_last_dim
is
None
:
last_output_shape
=
hidden_size
else
:
last_output_shape
=
self
.
_output_last_dim
common_kwargs
=
dict
(
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
...
...
@@ -146,11 +195,13 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
bias_constraint
=
self
.
_bias_constraint
)
self
.
_attention_layer
=
tf
.
keras
.
layers
.
MultiHeadAttention
(
num_heads
=
self
.
_num_heads
,
key_dim
=
self
.
_attention_head_size
,
key_dim
=
self
.
_key_dim
,
value_dim
=
self
.
_value_dim
,
dropout
=
self
.
_attention_dropout
,
use_bias
=
self
.
_use_bias
,
kernel_initializer
=
self
.
_attention_initializer
,
attention_axes
=
self
.
_attention_axes
,
output_shape
=
self
.
_output_last_dim
,
name
=
"self_attention"
,
**
common_kwargs
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
...
...
@@ -162,6 +213,15 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
))
self
.
_attention_layer_norm_kv
=
self
.
_attention_layer_norm
if
self
.
_diff_q_kv_att_layer_norm
:
self
.
_attention_layer_norm_kv
=
(
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"self_attention_layer_norm_kv"
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
))
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
output_shape
=
(
None
,
self
.
_inner_dim
),
...
...
@@ -181,7 +241,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
rate
=
self
.
_inner_dropout
)
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
output_shape
=
(
None
,
hidden_siz
e
),
output_shape
=
(
None
,
last_output_shap
e
),
bias_axes
=
"d"
,
name
=
"output"
,
kernel_initializer
=
self
.
_kernel_initializer
,
...
...
@@ -235,6 +295,16 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
"attention_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_attention_initializer
),
"attention_axes"
:
self
.
_attention_axes
,
"use_query_residual"
:
self
.
_use_query_residual
,
"key_dim"
:
self
.
_key_dim
,
"value_dim"
:
self
.
_value_dim
,
"output_last_dim"
:
self
.
_output_last_dim
,
"diff_q_kv_att_layer_norm"
:
self
.
_diff_q_kv_att_layer_norm
,
}
base_config
=
super
(
TransformerEncoderBlock
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
...
...
@@ -271,7 +341,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
source_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm
(
key_value
)
key_value
=
self
.
_attention_layer_norm
_kv
(
key_value
)
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
...
...
@@ -280,7 +350,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
source_tensor
=
input_tensor
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm
(
key_value
)
key_value
=
self
.
_attention_layer_norm
_kv
(
key_value
)
target_tensor
=
input_tensor
if
key_value
is
None
:
...
...
@@ -288,11 +358,18 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
attention_output
=
self
.
_attention_layer
(
query
=
target_tensor
,
value
=
key_value
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
if
self
.
_norm_first
:
attention_output
=
source_tensor
+
attention_output
# Important to not combine `self._norm_first` and
# `self._use_query_residual` into one if clause because else is only for
# `_norm_first == False`.
if
self
.
_use_query_residual
:
attention_output
=
source_tensor
+
attention_output
else
:
attention_output
=
self
.
_attention_layer_norm
(
target_tensor
+
attention_output
)
if
self
.
_use_query_residual
:
attention_output
=
target_tensor
+
attention_output
attention_output
=
self
.
_attention_layer_norm
(
attention_output
)
if
self
.
_norm_first
:
source_attention_output
=
attention_output
attention_output
=
self
.
_output_layer_norm
(
attention_output
)
...
...
official/nlp/modeling/layers/transformer_encoder_block_test.py
View file @
7b2df02f
...
...
@@ -252,6 +252,182 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
self
.
assertEqual
(
output
.
shape
,
q_tensor
.
shape
)
@
keras_parameterized
.
run_all_keras_modes
class
TransformerEncoderBlockLayerTestWithoutParams
(
keras_parameterized
.
TestCase
):
def
tearDown
(
self
):
super
(
TransformerEncoderBlockLayerTestWithoutParams
,
self
).
tearDown
()
tf
.
keras
.
mixed_precision
.
set_global_policy
(
'float32'
)
def
test_raises_invalid_arg_error_when_q_kv_dims_are_different
(
self
):
test_layer
=
TransformerEncoderBlock
(
num_attention_heads
=
2
,
inner_dim
=
128
,
inner_activation
=
'relu'
,
norm_first
=
True
)
# Forward path.
q_tensor
=
tf
.
zeros
([
2
,
4
,
16
],
dtype
=
tf
.
float32
)
kv_tensor
=
tf
.
zeros
([
2
,
8
,
32
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
8
],
dtype
=
tf
.
float32
)
inputs
=
[
q_tensor
,
kv_tensor
,
dummy_mask
]
with
self
.
assertRaises
(
tf
.
errors
.
InvalidArgumentError
):
test_layer
(
inputs
)
@
parameterized
.
named_parameters
(
(
'output_range_not_none'
,
2
),
(
'output_range_none'
,
None
))
def
test_needs_diff_q_kv_att_layer_norm_to_be_true_for_diff_q_and_kv_dims
(
self
,
output_range
):
test_layer
=
TransformerEncoderBlock
(
num_attention_heads
=
2
,
inner_dim
=
128
,
inner_activation
=
'relu'
,
output_range
=
output_range
,
norm_first
=
True
)
# Forward path.
q_tensor
=
tf
.
zeros
([
2
,
4
,
16
],
dtype
=
tf
.
float32
)
kv_tensor
=
tf
.
zeros
([
2
,
8
,
32
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
8
],
dtype
=
tf
.
float32
)
inputs
=
[
q_tensor
,
kv_tensor
,
dummy_mask
]
with
self
.
assertRaises
(
tf
.
errors
.
InvalidArgumentError
):
test_layer
(
inputs
)
test_layer
=
TransformerEncoderBlock
(
num_attention_heads
=
2
,
inner_dim
=
128
,
inner_activation
=
'relu'
,
diff_q_kv_att_layer_norm
=
True
,
norm_first
=
True
)
# Forward path.
test_layer
(
inputs
)
@
parameterized
.
named_parameters
(
(
'norm_first_is_true'
,
True
),
(
'norm_first_is_false'
,
False
))
def
test_use_query_residual_false_removes_add_op
(
self
,
norm_first
):
graph_with_res
=
tf
.
Graph
()
with
graph_with_res
.
as_default
():
layer
=
TransformerEncoderBlock
(
num_attention_heads
=
2
,
inner_dim
=
128
,
inner_activation
=
'relu'
,
norm_first
=
norm_first
)
inputs
=
tf
.
keras
.
Input
(
shape
=
(
None
,
None
,
2
))
outputs
=
layer
(
inputs
)
tf
.
keras
.
Model
(
inputs
=
inputs
,
outputs
=
outputs
)
graph_without_res
=
tf
.
Graph
()
with
graph_without_res
.
as_default
():
layer
=
TransformerEncoderBlock
(
num_attention_heads
=
2
,
inner_dim
=
128
,
inner_activation
=
'relu'
,
norm_first
=
norm_first
,
use_query_residual
=
False
)
inputs
=
tf
.
keras
.
Input
(
shape
=
(
None
,
None
,
2
))
outputs
=
layer
(
inputs
)
tf
.
keras
.
Model
(
inputs
=
inputs
,
outputs
=
outputs
)
graph_with_res_names
=
{
x
.
name
for
x
in
graph_with_res
.
get_operations
()}
graph_without_res_names
=
{
x
.
name
for
x
in
graph_without_res
.
get_operations
()
}
self
.
assertIn
(
'transformer_encoder_block/add'
,
list
(
graph_with_res_names
-
graph_without_res_names
)[
0
])
self
.
assertEmpty
(
graph_without_res_names
-
graph_with_res_names
)
@
parameterized
.
named_parameters
(
(
'key_dim_is_none'
,
None
,
128
,
2
,
128
//
2
),
(
'key_dim_is_not_none'
,
30
,
128
,
2
,
30
))
def
test_key_dim
(
self
,
key_dim
,
q_tensor_last_dim
,
some_num_attention_heads
,
expected
):
some_inner_dim
=
32
some_inner_activation
=
'relu'
test_layer
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
key_dim
=
key_dim
)
q_tensor
=
tf
.
zeros
([
2
,
4
,
q_tensor_last_dim
],
dtype
=
tf
.
float32
)
kv_tensor
=
tf
.
zeros
([
2
,
8
,
32
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
8
],
dtype
=
tf
.
float32
)
test_layer
([
q_tensor
,
kv_tensor
,
dummy_mask
])
self
.
assertEqual
(
expected
,
test_layer
.
_attention_layer
.
get_config
()[
'key_dim'
])
@
parameterized
.
named_parameters
(
(
'output_last_dim_is_none_use_query_residual_false'
,
False
,
None
,
128
,
128
),
(
'output_last_dim_is_none_use_query_residual_true'
,
True
,
None
,
128
,
128
),
(
'output_last_dim_is_not_none'
,
False
,
30
,
128
,
30
))
def
test_output_last_dim
(
self
,
use_query_residual
,
output_last_dim
,
q_tensor_last_dim
,
expected
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
test_layer
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
# Must be false for multi-head output to be different from
# first input's last dim
use_query_residual
=
use_query_residual
,
output_last_dim
=
output_last_dim
)
q_tensor
=
tf
.
zeros
([
2
,
4
,
q_tensor_last_dim
],
dtype
=
tf
.
float32
)
kv_tensor
=
tf
.
zeros
([
2
,
8
,
32
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
8
],
dtype
=
tf
.
float32
)
output
=
test_layer
([
q_tensor
,
kv_tensor
,
dummy_mask
])
self
.
assertEqual
(
output
.
numpy
().
shape
[
-
1
],
expected
)
@
parameterized
.
named_parameters
(
(
'value_dim_is_none'
,
None
,
128
,
2
,
128
//
2
),
(
'value_dim_is_not_none'
,
30
,
128
,
2
,
30
))
def
test_value_dim
(
self
,
value_dim
,
q_tensor_last_dim
,
some_num_attention_heads
,
expected
):
some_inner_dim
=
32
some_inner_activation
=
'relu'
test_layer
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
value_dim
=
value_dim
)
q_tensor
=
tf
.
zeros
([
2
,
4
,
q_tensor_last_dim
],
dtype
=
tf
.
float32
)
kv_tensor
=
tf
.
zeros
([
2
,
8
,
32
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
8
],
dtype
=
tf
.
float32
)
test_layer
([
q_tensor
,
kv_tensor
,
dummy_mask
])
self
.
assertEqual
(
expected
,
test_layer
.
_attention_layer
.
get_config
()[
'value_dim'
])
@
keras_parameterized
.
run_all_keras_modes
class
TransformerArgumentTest
(
keras_parameterized
.
TestCase
):
...
...
@@ -277,6 +453,138 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
output
=
encoder_block
(
inputs
)
self
.
assertEqual
(
output
.
shape
,
(
2
,
4
,
hidden_size
))
def
test_norm_first_false_and_diff_q_kv_att_layer_norm_true_raises
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
with
self
.
assertRaises
(
ValueError
):
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
norm_first
=
False
,
diff_q_kv_att_layer_norm
=
True
)
def
test_diff_q_kv_att_layer_norm_is_part_of_config_1
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
norm_first
=
False
)
self
.
assertIn
(
'diff_q_kv_att_layer_norm'
,
encoder
.
get_config
())
self
.
assertFalse
(
encoder
.
get_config
()[
'diff_q_kv_att_layer_norm'
])
def
test_diff_q_kv_att_layer_norm_is_part_of_config_2
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
norm_first
=
True
,
diff_q_kv_att_layer_norm
=
True
)
self
.
assertIn
(
'diff_q_kv_att_layer_norm'
,
encoder
.
get_config
())
self
.
assertTrue
(
encoder
.
get_config
()[
'diff_q_kv_att_layer_norm'
])
def
test_use_query_residual_is_part_of_config_1
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
)
self
.
assertIn
(
'use_query_residual'
,
encoder
.
get_config
())
self
.
assertTrue
(
encoder
.
get_config
()[
'use_query_residual'
])
def
test_use_query_residual_is_part_of_config_2
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
use_query_residual
=
False
)
self
.
assertIn
(
'use_query_residual'
,
encoder
.
get_config
())
self
.
assertFalse
(
encoder
.
get_config
()[
'use_query_residual'
])
def
test_key_dim_is_part_of_config_1
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
)
self
.
assertIn
(
'key_dim'
,
encoder
.
get_config
())
self
.
assertIsNone
(
encoder
.
get_config
()[
'key_dim'
])
def
test_key_dim_is_part_of_config_2
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
key_dim
=
10
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
key_dim
=
key_dim
)
self
.
assertIn
(
'key_dim'
,
encoder
.
get_config
())
self
.
assertEqual
(
key_dim
,
encoder
.
get_config
()[
'key_dim'
])
def
test_value_dim_is_part_of_config_1
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
)
self
.
assertIn
(
'value_dim'
,
encoder
.
get_config
())
self
.
assertIsNone
(
encoder
.
get_config
()[
'value_dim'
])
def
test_value_dim_is_part_of_config_2
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
value_dim
=
10
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
value_dim
=
value_dim
)
self
.
assertIn
(
'value_dim'
,
encoder
.
get_config
())
self
.
assertEqual
(
value_dim
,
encoder
.
get_config
()[
'value_dim'
])
def
test_output_last_dim_is_part_of_config_1
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
)
self
.
assertIn
(
'output_last_dim'
,
encoder
.
get_config
())
self
.
assertIsNone
(
encoder
.
get_config
()[
'output_last_dim'
])
def
test_output_last_dim_is_part_of_config_2
(
self
):
some_num_attention_heads
=
2
some_inner_dim
=
32
some_inner_activation
=
'relu'
output_last_dim
=
10
encoder
=
TransformerEncoderBlock
(
num_attention_heads
=
some_num_attention_heads
,
inner_dim
=
some_inner_dim
,
inner_activation
=
some_inner_activation
,
output_last_dim
=
output_last_dim
)
self
.
assertIn
(
'output_last_dim'
,
encoder
.
get_config
())
self
.
assertEqual
(
output_last_dim
,
encoder
.
get_config
()[
'output_last_dim'
])
def
test_get_config
(
self
):
num_attention_heads
=
2
encoder_block
=
TransformerEncoderBlock
(
...
...
@@ -290,7 +598,12 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
norm_epsilon
=
1e-6
,
inner_dropout
=
0.1
,
attention_initializer
=
tf
.
keras
.
initializers
.
RandomUniform
(
minval
=
0.
,
maxval
=
1.
))
minval
=
0.
,
maxval
=
1.
),
use_query_residual
=
False
,
key_dim
=
20
,
value_dim
=
30
,
output_last_dim
=
40
,
diff_q_kv_att_layer_norm
=
True
)
encoder_block_config
=
encoder_block
.
get_config
()
new_encoder_block
=
TransformerEncoderBlock
.
from_config
(
encoder_block_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment