Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
32e4ca51
Commit
32e4ca51
authored
Nov 28, 2023
by
qianyj
Browse files
Update code to v2.11.0
parents
9485aa1d
71060f67
Changes
772
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
556 additions
and
158 deletions
+556
-158
official/nlp/modeling/layers/per_dim_scale_attention.py
official/nlp/modeling/layers/per_dim_scale_attention.py
+101
-0
official/nlp/modeling/layers/per_dim_scale_attention_test.py
official/nlp/modeling/layers/per_dim_scale_attention_test.py
+52
-0
official/nlp/modeling/layers/position_embedding.py
official/nlp/modeling/layers/position_embedding.py
+3
-3
official/nlp/modeling/layers/position_embedding_test.py
official/nlp/modeling/layers/position_embedding_test.py
+1
-1
official/nlp/modeling/layers/relative_attention.py
official/nlp/modeling/layers/relative_attention.py
+24
-25
official/nlp/modeling/layers/relative_attention_test.py
official/nlp/modeling/layers/relative_attention_test.py
+1
-1
official/nlp/modeling/layers/reuse_attention.py
official/nlp/modeling/layers/reuse_attention.py
+48
-27
official/nlp/modeling/layers/reuse_attention_test.py
official/nlp/modeling/layers/reuse_attention_test.py
+1
-1
official/nlp/modeling/layers/reuse_transformer.py
official/nlp/modeling/layers/reuse_transformer.py
+12
-7
official/nlp/modeling/layers/reuse_transformer_test.py
official/nlp/modeling/layers/reuse_transformer_test.py
+14
-14
official/nlp/modeling/layers/rezero_transformer.py
official/nlp/modeling/layers/rezero_transformer.py
+44
-27
official/nlp/modeling/layers/rezero_transformer_test.py
official/nlp/modeling/layers/rezero_transformer_test.py
+4
-1
official/nlp/modeling/layers/routing.py
official/nlp/modeling/layers/routing.py
+125
-0
official/nlp/modeling/layers/routing_test.py
official/nlp/modeling/layers/routing_test.py
+59
-0
official/nlp/modeling/layers/self_attention_mask.py
official/nlp/modeling/layers/self_attention_mask.py
+31
-25
official/nlp/modeling/layers/spectral_normalization.py
official/nlp/modeling/layers/spectral_normalization.py
+11
-12
official/nlp/modeling/layers/spectral_normalization_test.py
official/nlp/modeling/layers/spectral_normalization_test.py
+2
-2
official/nlp/modeling/layers/talking_heads_attention.py
official/nlp/modeling/layers/talking_heads_attention.py
+5
-3
official/nlp/modeling/layers/talking_heads_attention_test.py
official/nlp/modeling/layers/talking_heads_attention_test.py
+1
-1
official/nlp/modeling/layers/text_layers.py
official/nlp/modeling/layers/text_layers.py
+17
-8
No files found.
Too many changes to show.
To preserve performance only
772 of 772+
files are displayed.
Plain diff
Email patch
official/nlp/modeling/layers/per_dim_scale_attention.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based attention layer with learnable per dim scaling."""
import
gin
import
numpy
as
np
import
tensorflow
as
tf
@gin.configurable
@tf.keras.utils.register_keras_serializable(package='Text')
class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention):
  """Learn scales for individual dims.

  It can improve quality but might hurt training stability.
  """

  def _build_from_signature(self, query, value, key=None):
    """Builds layers and variables, then adds the learnable per-dim scale.

    Extends the base `MultiHeadAttention` build step with a trainable
    vector `per_dim_scale` of length `_key_dim`, used by `_scale_query`.

    Args:
      query: Query tensor or TensorShape.
      value: Value tensor or TensorShape.
      key: Optional key tensor or TensorShape.
    """
    super()._build_from_signature(query=query, value=value, key=key)
    # pytype: disable=attribute-error
    # `_key_dim` is set by the base class; the scale has one entry per
    # key/query head dimension.
    self._scale_dim = self._key_dim
    # init_scope lifts variable creation out of any tf.function trace.
    with tf.init_scope():
      self.per_dim_scale = self.add_weight(
          name='per_dim_scale',
          shape=(self._scale_dim,),
          # Zeros => softplus(0) for every dim, i.e. a uniform initial scale.
          initializer='zeros',
          dtype=self.dtype,
          trainable=True)

  def _scale_query(self, query):
    """Scales `query` per head dimension by softplus(per_dim_scale).

    The constant is chosen so that at initialization (per_dim_scale == 0)
    the overall factor reduces to 1/sqrt(_scale_dim), the standard
    scaled-dot-product normalization.

    Args:
      query: Query tensor; last dimension is assumed to be `_scale_dim`
        (the per-head key dimension) — confirmed by the base layer's
        `_query_dense` projection.

    Returns:
      The scaled query tensor, same shape and dtype as `query`.
    """
    # 1.0/tf.nn.softplus(0.0) = 1.442695041. Hard code this number so that we
    # can avoid unnecessary XLA op fusion mess on TPU.
    r_softplus_0 = 1.442695041
    scale = tf.constant(
        r_softplus_0 / np.sqrt(float(self._scale_dim)), dtype=query.dtype)
    # softplus keeps the learned per-dimension factors strictly positive.
    scale *= tf.nn.softplus(self.per_dim_scale)
    return query * scale

  def _compute_attention(self,
                         query,
                         key,
                         value,
                         attention_mask=None,
                         training=None):
    """Applies dot-product attention with per-dim query scaling.

    Same computation as the base class except the query is first passed
    through `_scale_query` (which replaces the base class's fixed
    1/sqrt(key_dim) scaling).

    Args:
      query: Projected query tensor, `[B, T, N, H]`.
      key: Projected key tensor, `[B, S, N, H]`.
      value: Projected value tensor, `[B, S, N, H]`.
      attention_mask: Optional mask forwarded to `_masked_softmax`.
      training: Optional bool forwarded to the dropout layer.

    Returns:
      A tuple of (attention_output, attention_scores).
    """
    query = self._scale_query(query)
    # NOTE: `_dot_product_equation` expects operands in (key, query) order.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)
    attention_scores = self._masked_softmax(attention_scores, attention_mask)
    attention_scores_dropout = self._dropout_layer(
        attention_scores, training=training)
    # `context_layer` = [B, T, N, H]
    attention_output = tf.einsum(self._combine_equation,
                                 attention_scores_dropout, value)
    return attention_output, attention_scores

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           return_attention_scores=False,
           training=None,
           ):
    """Computes multi-head attention with learnable per-dim query scales.

    Args:
      query: Query tensor of shape `[B, T, dim]`.
      value: Value tensor of shape `[B, S, dim]`.
      key: Optional key tensor; defaults to `value` when None.
      attention_mask: Optional boolean mask preventing attention to
        certain positions.
      return_attention_scores: If True, also return the attention scores.
      training: Optional bool controlling dropout behavior.

    Returns:
      The attention output tensor, or a tuple of (output, scores) when
      `return_attention_scores` is True.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # N = `num_attention_heads`
    # H = `size_per_head`
    # `query` = [B, T, N ,H]
    query = self._query_dense(query)

    # `key` = [B, S, N, H]
    key = self._key_dense(key)

    # `value` = [B, S, N, H]
    value = self._value_dense(value)

    attention_output, attention_scores = self._compute_attention(
        query, key, value, attention_mask, training)
    attention_output = self._output_dense(attention_output)

    if return_attention_scores:
      return attention_output, attention_scores
    return attention_output
official/nlp/modeling/layers/per_dim_scale_attention_test.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PerDimScaleAttention."""
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
per_dim_scale_attention
as
attention
class PerDimScaleAttentionTest(tf.test.TestCase):
  """Tests for the PerDimScaleAttention layer."""

  def test_attention(self):
    """Self-attention on random input preserves the input shape."""
    num_heads, key_dim = 12, 64
    seq_length, batch_size = 1024, 2
    hidden_size = key_dim * num_heads

    layer = attention.PerDimScaleAttention(
        num_heads=num_heads, key_dim=key_dim)
    inputs = tf.random.normal(shape=(batch_size, seq_length, hidden_size))

    result = layer(query=inputs, value=inputs)

    self.assertEqual(result.shape, [batch_size, seq_length, hidden_size])

  def test_config(self):
    """A layer rebuilt via from_config round-trips its configuration."""
    num_heads, key_dim = 12, 64
    layer = attention.PerDimScaleAttention(
        num_heads=num_heads, key_dim=key_dim)
    config = layer.get_config()
    print(config)

    rebuilt = attention.PerDimScaleAttention.from_config(config)

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(config, rebuilt.get_config())
if __name__ == '__main__':
  # Run all test cases in this module when executed as a script.
  tf.test.main()
official/nlp/modeling/layers/position_embedding.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -53,7 +53,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
...
@@ -53,7 +53,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
seq_axis
=
1
,
seq_axis
=
1
,
**
kwargs
):
**
kwargs
):
super
(
PositionEmbedding
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
if
max_length
is
None
:
if
max_length
is
None
:
raise
ValueError
(
raise
ValueError
(
"`max_length` must be an Integer, not `None`."
"`max_length` must be an Integer, not `None`."
...
@@ -81,7 +81,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
...
@@ -81,7 +81,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
shape
=
[
weight_sequence_length
,
width
],
shape
=
[
weight_sequence_length
,
width
],
initializer
=
self
.
_initializer
)
initializer
=
self
.
_initializer
)
super
(
PositionEmbedding
,
self
).
build
(
input_shape
)
super
().
build
(
input_shape
)
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
input_shape
=
tf
.
shape
(
inputs
)
input_shape
=
tf
.
shape
(
inputs
)
...
...
official/nlp/modeling/layers/position_embedding_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
official/nlp/modeling/layers/relative_attention.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -98,14 +98,14 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -98,14 +98,14 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
`[B, L, dim]`.
`[B, L, dim]`.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet of shape `[B, S, S + M]`.
XLNet of shape `[B, S, S + M]`.
segment_encoding: Optional `Tensor` representing the segmentation
segment_encoding: Optional `Tensor` representing the segmentation
encoding
encoding
as used in XLNet of shape `[2, num_heads, dim]`.
as used in XLNet of shape `[2, num_heads, dim]`.
segment_attention_bias: Optional trainable bias parameter added to the
segment_attention_bias: Optional trainable bias parameter added to the
query
query
had when calculating the segment-based attention score used in
had when calculating the segment-based attention score used in
XLNet of
XLNet of
shape `[num_heads, dim]`.
shape `[num_heads, dim]`.
state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the
state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the
state or memory.
state or memory.
If passed, this is also attended over as in Transformer
If passed, this is also attended over as in Transformer
XL.
XL.
attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
to certain positions.
to certain positions.
"""
"""
...
@@ -144,7 +144,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -144,7 +144,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
with
tf
.
init_scope
():
with
tf
.
init_scope
():
einsum_equation
,
_
,
output_rank
=
_build_proj_equation
(
einsum_equation
,
_
,
output_rank
=
_build_proj_equation
(
key_shape
.
rank
-
1
,
bound_dims
=
1
,
output_dims
=
2
)
key_shape
.
rank
-
1
,
bound_dims
=
1
,
output_dims
=
2
)
self
.
_encoding_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_encoding_dense
=
tf
.
keras
.
layers
.
EinsumDense
(
einsum_equation
,
einsum_equation
,
output_shape
=
_get_output_shape
(
output_rank
-
1
,
output_shape
=
_get_output_shape
(
output_rank
-
1
,
[
self
.
_num_heads
,
self
.
_key_dim
]),
[
self
.
_num_heads
,
self
.
_key_dim
]),
...
@@ -255,8 +255,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -255,8 +255,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
Args:
Args:
query: attention input.
query: attention input.
value: attention input.
value: attention input.
content_attention_bias: A trainable bias parameter added to the query
content_attention_bias: A trainable bias parameter added to the query
head
head
when calculating the content-based attention score.
when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
head when calculating the position-based attention score.
key: attention input.
key: attention input.
...
@@ -264,8 +264,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -264,8 +264,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
value.
value.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
segment_encoding: Optional `Tensor` representing the segmentation
encoding
encoding
as used in XLNet.
as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in
query had when calculating the segment-based attention score used in
XLNet.
XLNet.
...
@@ -394,22 +394,22 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
...
@@ -394,22 +394,22 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
content_stream: The content representation, commonly referred to as h.
content_stream: The content representation, commonly referred to as h.
This serves a similar role to the standard hidden states in
This serves a similar role to the standard hidden states in
Transformer-XL.
Transformer-XL.
content_attention_bias: A trainable bias parameter added to the query
content_attention_bias: A trainable bias parameter added to the query
head
head
when calculating the content-based attention score.
when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
head when calculating the position-based attention score.
query_stream: The query representation, commonly referred to as g.
query_stream: The query representation, commonly referred to as g.
This
This
only has access to contextual information and position, but not
only has access to contextual information and position, but not
content.
content.
If not provided, then this is MultiHeadRelativeAttention with
If not provided, then this is MultiHeadRelativeAttention with
self-attention.
self-attention.
relative_position_encoding: relative positional encoding for key and
relative_position_encoding: relative positional encoding for key and
value.
value.
target_mapping: Optional `Tensor` representing the target mapping used
target_mapping: Optional `Tensor` representing the target mapping used
in
in
partial prediction.
partial prediction.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
segment_encoding: Optional `Tensor` representing the segmentation
encoding
encoding
as used in XLNet.
as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score.
query head when calculating the segment-based attention score.
state: (default None) optional state. If passed, this is also attended
state: (default None) optional state. If passed, this is also attended
...
@@ -417,8 +417,8 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
...
@@ -417,8 +417,8 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
content_attention_mask: (default None) Optional mask that is added to
content_attention_mask: (default None) Optional mask that is added to
content attention logits. If state is not None, the mask source sequence
content attention logits. If state is not None, the mask source sequence
dimension should extend M.
dimension should extend M.
query_attention_mask: (default None) Optional mask that is added to
query_attention_mask: (default None) Optional mask that is added to
query
query
attention logits. If state is not None, the mask source sequence
attention logits. If state is not None, the mask source sequence
dimension should extend M.
dimension should extend M.
Returns:
Returns:
...
@@ -496,4 +496,3 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
...
@@ -496,4 +496,3 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
query_attention_output
=
self
.
_output_dense
(
query_attention_output
)
query_attention_output
=
self
.
_output_dense
(
query_attention_output
)
return
content_attention_output
,
query_attention_output
return
content_attention_output
,
query_attention_output
official/nlp/modeling/layers/relative_attention_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
official/nlp/modeling/layers/reuse_attention.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -22,6 +22,8 @@ import string
...
@@ -22,6 +22,8 @@ import string
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
_CHR_IDX
=
string
.
ascii_lowercase
_CHR_IDX
=
string
.
ascii_lowercase
...
@@ -221,7 +223,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
...
@@ -221,7 +223,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
kernel_constraint
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
bias_constraint
=
None
,
**
kwargs
):
**
kwargs
):
super
(
ReuseMultiHeadAttention
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
_num_heads
=
num_heads
self
.
_num_heads
=
num_heads
self
.
_key_dim
=
key_dim
self
.
_key_dim
=
key_dim
self
.
_value_dim
=
value_dim
if
value_dim
else
key_dim
self
.
_value_dim
=
value_dim
if
value_dim
else
key_dim
...
@@ -299,7 +301,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
...
@@ -299,7 +301,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
"key_shape"
:
self
.
_key_shape
,
"key_shape"
:
self
.
_key_shape
,
"value_shape"
:
self
.
_value_shape
,
"value_shape"
:
self
.
_value_shape
,
}
}
base_config
=
super
(
ReuseMultiHeadAttention
,
self
).
get_config
()
base_config
=
super
().
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
@
classmethod
@
classmethod
...
@@ -347,8 +349,6 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
...
@@ -347,8 +349,6 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self
.
_key_shape
=
tf
.
TensorShape
(
key
)
self
.
_key_shape
=
tf
.
TensorShape
(
key
)
common_kwargs
=
dict
(
common_kwargs
=
dict
(
kernel_initializer
=
self
.
_kernel_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
...
@@ -362,42 +362,61 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
...
@@ -362,42 +362,61 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
if
self
.
_reuse_heads
<
self
.
_num_heads
:
if
self
.
_reuse_heads
<
self
.
_num_heads
:
einsum_equation
,
bias_axes
,
output_rank
=
_build_proj_equation
(
einsum_equation
,
bias_axes
,
output_rank
=
_build_proj_equation
(
free_dims
,
bound_dims
=
1
,
output_dims
=
2
)
free_dims
,
bound_dims
=
1
,
output_dims
=
2
)
self
.
_query_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_query_dense
=
tf
.
keras
.
layers
.
EinsumDense
(
einsum_equation
,
einsum_equation
,
output_shape
=
_get_output_shape
(
output_rank
-
1
,
[
output_shape
=
_get_output_shape
(
self
.
_num_heads
-
self
.
_reuse_heads
,
self
.
_key_dim
]),
output_rank
-
1
,
[
self
.
_num_heads
-
self
.
_reuse_heads
,
self
.
_key_dim
]),
bias_axes
=
bias_axes
if
self
.
_use_bias
else
None
,
bias_axes
=
bias_axes
if
self
.
_use_bias
else
None
,
name
=
"query"
,
name
=
"query"
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
)
**
common_kwargs
)
einsum_equation
,
bias_axes
,
output_rank
=
_build_proj_equation
(
einsum_equation
,
bias_axes
,
output_rank
=
_build_proj_equation
(
self
.
_key_shape
.
rank
-
1
,
bound_dims
=
1
,
output_dims
=
2
)
self
.
_key_shape
.
rank
-
1
,
bound_dims
=
1
,
output_dims
=
2
)
self
.
_key_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_key_dense
=
tf
.
keras
.
layers
.
EinsumDense
(
einsum_equation
,
einsum_equation
,
output_shape
=
_get_output_shape
(
output_rank
-
1
,
[
output_shape
=
_get_output_shape
(
self
.
_num_heads
-
self
.
_reuse_heads
,
self
.
_key_dim
]),
output_rank
-
1
,
[
self
.
_num_heads
-
self
.
_reuse_heads
,
self
.
_key_dim
]),
bias_axes
=
bias_axes
if
self
.
_use_bias
else
None
,
bias_axes
=
bias_axes
if
self
.
_use_bias
else
None
,
name
=
"key"
,
name
=
"key"
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
)
**
common_kwargs
)
einsum_equation
,
bias_axes
,
output_rank
=
_build_proj_equation
(
einsum_equation
,
bias_axes
,
output_rank
=
_build_proj_equation
(
self
.
_value_shape
.
rank
-
1
,
bound_dims
=
1
,
output_dims
=
2
)
self
.
_value_shape
.
rank
-
1
,
bound_dims
=
1
,
output_dims
=
2
)
self
.
_value_dense
=
[]
self
.
_value_dense
=
[]
if
self
.
_reuse_heads
>
0
:
if
self
.
_reuse_heads
>
0
:
self
.
_value_dense
.
append
(
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_value_dense
.
append
(
einsum_equation
,
tf
.
keras
.
layers
.
EinsumDense
(
output_shape
=
_get_output_shape
(
einsum_equation
,
output_rank
-
1
,
[
self
.
_reuse_heads
,
self
.
_value_dim
]),
output_shape
=
_get_output_shape
(
bias_axes
=
bias_axes
if
self
.
_use_bias
else
None
,
output_rank
-
1
,
[
self
.
_reuse_heads
,
self
.
_value_dim
]),
name
=
"value_reuse"
,
bias_axes
=
bias_axes
if
self
.
_use_bias
else
None
,
**
common_kwargs
))
name
=
"value_reuse"
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
))
if
self
.
_reuse_heads
<
self
.
_num_heads
:
if
self
.
_reuse_heads
<
self
.
_num_heads
:
self
.
_value_dense
.
append
(
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_value_dense
.
append
(
einsum_equation
,
tf
.
keras
.
layers
.
EinsumDense
(
output_shape
=
_get_output_shape
(
output_rank
-
1
,
[
einsum_equation
,
self
.
_num_heads
-
self
.
_reuse_heads
,
self
.
_value_dim
]),
output_shape
=
_get_output_shape
(
bias_axes
=
bias_axes
if
self
.
_use_bias
else
None
,
output_rank
-
1
,
name
=
"value_new"
,
[
self
.
_num_heads
-
self
.
_reuse_heads
,
self
.
_value_dim
]),
**
common_kwargs
))
bias_axes
=
bias_axes
if
self
.
_use_bias
else
None
,
name
=
"value_new"
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
))
# Builds the attention computations for multi-head dot product attention.
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once
# These computations could be wrapped into the keras attention layer once
...
@@ -434,18 +453,20 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
...
@@ -434,18 +453,20 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
output_shape
=
[
self
.
_query_shape
[
-
1
]]
output_shape
=
[
self
.
_query_shape
[
-
1
]]
einsum_equation
,
bias_axes
,
output_rank
=
_build_proj_equation
(
einsum_equation
,
bias_axes
,
output_rank
=
_build_proj_equation
(
free_dims
,
bound_dims
=
2
,
output_dims
=
len
(
output_shape
))
free_dims
,
bound_dims
=
2
,
output_dims
=
len
(
output_shape
))
return
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
return
tf
.
keras
.
layers
.
EinsumDense
(
einsum_equation
,
einsum_equation
,
output_shape
=
_get_output_shape
(
output_rank
-
1
,
output_shape
),
output_shape
=
_get_output_shape
(
output_rank
-
1
,
output_shape
),
bias_axes
=
bias_axes
if
(
use_bias
and
self
.
_use_bias
)
else
None
,
bias_axes
=
bias_axes
if
(
use_bias
and
self
.
_use_bias
)
else
None
,
name
=
name
,
name
=
name
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
)
**
common_kwargs
)
def
_build_attention
(
self
,
rank
):
def
_build_attention
(
self
,
rank
):
"""Builds multi-head dot-product attention computations.
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
This function builds attributes necessary for `_compute_attention` to
c
o
stomize attention computation to replace the default dot-product
c
u
stomize attention computation to replace the default dot-product
attention.
attention.
Args:
Args:
...
...
official/nlp/modeling/layers/reuse_attention_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
official/nlp/modeling/layers/reuse_transformer.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
"""Keras-based TransformerEncoder block layer."""
"""Keras-based TransformerEncoder block layer."""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.nlp.modeling.layers
import
reuse_attention
as
attention
from
official.nlp.modeling.layers
import
reuse_attention
as
attention
...
@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
...
@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
self
.
_attention_initializer
=
tf
.
keras
.
initializers
.
get
(
self
.
_attention_initializer
=
tf
.
keras
.
initializers
.
get
(
attention_initializer
)
attention_initializer
)
else
:
else
:
self
.
_attention_initializer
=
self
.
_kernel_initializer
self
.
_attention_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
)
self
.
_attention_axes
=
attention_axes
self
.
_attention_axes
=
attention_axes
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
...
@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer):
...
@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer):
else
:
else
:
self
.
_attention_head_size
=
self
.
_head_size
self
.
_attention_head_size
=
self
.
_head_size
common_kwargs
=
dict
(
common_kwargs
=
dict
(
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
...
@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
...
@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
dropout
=
self
.
_attention_dropout
,
dropout
=
self
.
_attention_dropout
,
use_bias
=
self
.
_use_bias
,
use_bias
=
self
.
_use_bias
,
kernel_initializer
=
self
.
_attention_initializer
,
kernel_initializer
=
self
.
_attention_initializer
,
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
attention_axes
=
self
.
_attention_axes
,
attention_axes
=
self
.
_attention_axes
,
reuse_attention
=
self
.
_reuse_attention
,
reuse_attention
=
self
.
_reuse_attention
,
use_relative_pe
=
self
.
_use_relative_pe
,
use_relative_pe
=
self
.
_use_relative_pe
,
...
@@ -184,11 +187,12 @@ class ReuseTransformer(tf.keras.layers.Layer):
...
@@ -184,11 +187,12 @@ class ReuseTransformer(tf.keras.layers.Layer):
axis
=-
1
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
))
dtype
=
tf
.
float32
))
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
EinsumDense
(
einsum_equation
,
einsum_equation
,
output_shape
=
(
None
,
self
.
_inner_dim
),
output_shape
=
(
None
,
self
.
_inner_dim
),
bias_axes
=
"d"
,
bias_axes
=
"d"
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
name
=
"intermediate"
,
name
=
"intermediate"
,
**
common_kwargs
)
**
common_kwargs
)
policy
=
tf
.
keras
.
mixed_precision
.
global_policy
()
policy
=
tf
.
keras
.
mixed_precision
.
global_policy
()
...
@@ -201,12 +205,13 @@ class ReuseTransformer(tf.keras.layers.Layer):
...
@@ -201,12 +205,13 @@ class ReuseTransformer(tf.keras.layers.Layer):
self
.
_inner_activation
,
dtype
=
policy
)
self
.
_inner_activation
,
dtype
=
policy
)
self
.
_inner_dropout_layer
=
tf
.
keras
.
layers
.
Dropout
(
self
.
_inner_dropout_layer
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_inner_dropout
)
rate
=
self
.
_inner_dropout
)
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_output_dense
=
tf
.
keras
.
layers
.
EinsumDense
(
einsum_equation
,
einsum_equation
,
output_shape
=
(
None
,
hidden_size
),
output_shape
=
(
None
,
hidden_size
),
bias_axes
=
"d"
,
bias_axes
=
"d"
,
name
=
"output"
,
name
=
"output"
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
)
**
common_kwargs
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
# Use float32 in layernorm for numeric stability.
# Use float32 in layernorm for numeric stability.
...
...
official/nlp/modeling/layers/reuse_transformer_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -68,7 +68,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -68,7 +68,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size
=
6
batch_size
=
6
input_data
=
10
*
np
.
random
.
random_sample
(
input_data
=
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
))
(
batch_size
,
sequence_length
,
width
))
_
=
model
.
predict
(
input_data
)
_
=
model
.
predict
(
input_data
)
...
@@ -89,7 +89,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -89,7 +89,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size
=
6
batch_size
=
6
input_data
=
10
*
np
.
random
.
random_sample
(
input_data
=
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
))
(
batch_size
,
sequence_length
,
width
))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
# which here is (batch, sequence_length, sequence_length)
...
@@ -104,7 +104,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -104,7 +104,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width
=
80
width
=
80
batch_size
=
6
batch_size
=
6
input_data
=
10
*
np
.
random
.
random_sample
(
input_data
=
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
))
(
batch_size
,
sequence_length
,
width
))
mask_data
=
np
.
random
.
randint
(
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
,
sequence_length
))
2
,
size
=
(
batch_size
,
sequence_length
,
sequence_length
))
...
@@ -121,7 +121,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -121,7 +121,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_output_tensor
,
_
=
new_layer
([
input_data
,
mask_data
])
new_output_tensor
,
_
=
new_layer
([
input_data
,
mask_data
])
self
.
assertAllClose
(
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
0.002
,
rtol
=
0.
25
)
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
0.002
,
rtol
=
0.
01
)
def
test_layer_output_range_with_relative_pe
(
self
,
transformer_cls
):
def
test_layer_output_range_with_relative_pe
(
self
,
transformer_cls
):
test_layer
=
transformer_cls
(
test_layer
=
transformer_cls
(
...
@@ -131,7 +131,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -131,7 +131,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width
=
80
width
=
80
batch_size
=
6
batch_size
=
6
input_data
=
10
*
np
.
random
.
random_sample
(
input_data
=
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
))
(
batch_size
,
sequence_length
,
width
))
mask_data
=
np
.
random
.
randint
(
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
,
sequence_length
))
2
,
size
=
(
batch_size
,
sequence_length
,
sequence_length
))
...
@@ -149,7 +149,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -149,7 +149,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_output_tensor
,
_
=
new_layer
([
input_data
,
mask_data
])
new_output_tensor
,
_
=
new_layer
([
input_data
,
mask_data
])
self
.
assertAllClose
(
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
5e-5
,
rtol
=
0.0
03
)
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
0.002
,
rtol
=
0.0
1
)
def
test_layer_output_range_without_mask
(
self
,
transformer_cls
):
def
test_layer_output_range_without_mask
(
self
,
transformer_cls
):
test_layer
=
transformer_cls
(
test_layer
=
transformer_cls
(
...
@@ -159,7 +159,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -159,7 +159,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width
=
80
width
=
80
batch_size
=
6
batch_size
=
6
input_data
=
10
*
np
.
random
.
random_sample
(
input_data
=
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
))
(
batch_size
,
sequence_length
,
width
))
output_tensor
,
_
=
test_layer
(
input_data
)
output_tensor
,
_
=
test_layer
(
input_data
)
...
@@ -175,7 +175,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -175,7 +175,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_output_tensor
,
_
=
new_layer
(
input_data
)
new_output_tensor
,
_
=
new_layer
(
input_data
)
self
.
assertAllClose
(
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
5e-5
,
rtol
=
0.0
03
)
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
0.002
,
rtol
=
0.0
1
)
def
test_layer_output_range_with_pre_norm
(
self
,
transformer_cls
):
def
test_layer_output_range_with_pre_norm
(
self
,
transformer_cls
):
test_layer
=
transformer_cls
(
test_layer
=
transformer_cls
(
...
@@ -185,7 +185,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -185,7 +185,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width
=
80
width
=
80
batch_size
=
6
batch_size
=
6
input_data
=
10
*
np
.
random
.
random_sample
(
input_data
=
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
))
(
batch_size
,
sequence_length
,
width
))
mask_data
=
np
.
random
.
randint
(
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
,
sequence_length
))
2
,
size
=
(
batch_size
,
sequence_length
,
sequence_length
))
...
@@ -203,7 +203,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -203,7 +203,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_output_tensor
,
_
=
new_layer
([
input_data
,
mask_data
])
new_output_tensor
,
_
=
new_layer
([
input_data
,
mask_data
])
self
.
assertAllClose
(
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
5e-5
,
rtol
=
0.0
03
)
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
0.002
,
rtol
=
0.0
1
)
def
test_layer_invocation_with_float16_dtype
(
self
,
transformer_cls
):
def
test_layer_invocation_with_float16_dtype
(
self
,
transformer_cls
):
tf
.
keras
.
mixed_precision
.
set_global_policy
(
'mixed_float16'
)
tf
.
keras
.
mixed_precision
.
set_global_policy
(
'mixed_float16'
)
...
@@ -223,7 +223,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -223,7 +223,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size
=
6
batch_size
=
6
input_data
=
(
10
*
np
.
random
.
random_sample
(
input_data
=
(
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
)))
(
batch_size
,
sequence_length
,
width
)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
# which here is (batch, sequence_length, sequence_length)
...
@@ -368,7 +368,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -368,7 +368,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size
=
6
batch_size
=
6
input_data
=
10
*
np
.
random
.
random_sample
(
input_data
=
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
))
(
batch_size
,
sequence_length
,
width
))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
# which here is (batch, sequence_length, sequence_length)
...
@@ -404,7 +404,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -404,7 +404,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size
=
6
batch_size
=
6
input_data
=
(
10
*
np
.
random
.
random_sample
(
input_data
=
(
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
)))
(
batch_size
,
sequence_length
,
width
)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
# which here is (batch, sequence_length, sequence_length)
...
...
official/nlp/modeling/layers/rezero_transformer.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -14,10 +14,13 @@
...
@@ -14,10 +14,13 @@
"""Keras-based rezero-transformer block layer (Transformer with ReZero)."""
"""Keras-based rezero-transformer block layer (Transformer with ReZero)."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
from
typing
import
Optional
from
absl
import
logging
import
gin
import
gin
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.nlp.modeling.layers
import
util
from
official.nlp.modeling.layers
import
util
...
@@ -33,8 +36,10 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -33,8 +36,10 @@ class ReZeroTransformer(tf.keras.layers.Layer):
Args:
Args:
num_attention_heads: Number of attention heads.
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
inner_dim: The output dimension of the first Dense layer in a two-layer
intermediate_activation: Activation for the intermediate layer.
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
dropout_rate: Dropout probability for the post-attention and output dropout.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
attention_dropout_rate: Dropout probability for within the attention layer.
output_range: the sequence output range, [0, output_range) by slicing the
output_range: the sequence output range, [0, output_range) by slicing the
...
@@ -52,8 +57,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -52,8 +57,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
def
__init__
(
self
,
def
__init__
(
self
,
num_attention_heads
,
num_attention_heads
,
in
t
er
mediate_size
,
in
n
er
_dim
=
768
,
in
t
er
mediate
_activation
,
in
n
er
_activation
=
tf_utils
.
get
_activation
(
"gelu"
)
,
dropout_rate
=
0.0
,
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
,
output_range
=
None
,
output_range
=
None
,
...
@@ -72,12 +77,19 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -72,12 +77,19 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_dropout_rate
=
kwargs
.
pop
(
"attention_dropout"
,
attention_dropout_rate
=
kwargs
.
pop
(
"attention_dropout"
,
attention_dropout_rate
)
attention_dropout_rate
)
dropout_rate
=
kwargs
.
pop
(
"output_dropout"
,
dropout_rate
)
dropout_rate
=
kwargs
.
pop
(
"output_dropout"
,
dropout_rate
)
inner_dim
=
kwargs
.
pop
(
"intermediate_size"
,
inner_dim
)
inner_activation
=
kwargs
.
pop
(
"intermediate_activation"
,
inner_activation
)
util
.
filter_kwargs
(
kwargs
)
util
.
filter_kwargs
(
kwargs
)
super
(
ReZeroTransformer
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
# Deprecation warning.
if
output_range
is
not
None
:
logging
.
warning
(
"`output_range` is avaliable as an argument for `call()`."
"The `output_range` as __init__ argument is deprecated."
)
self
.
_num_heads
=
num_attention_heads
self
.
_num_heads
=
num_attention_heads
self
.
_in
t
er
mediate_size
=
in
t
er
mediate_size
self
.
_in
n
er
_dim
=
in
n
er
_dim
self
.
_in
t
er
mediate
_activation
=
in
t
er
mediate
_activation
self
.
_in
n
er_activation
=
in
n
er_activation
self
.
_attention_dropout_rate
=
attention_dropout_rate
self
.
_attention_dropout_rate
=
attention_dropout_rate
self
.
_dropout_rate
=
dropout_rate
self
.
_dropout_rate
=
dropout_rate
self
.
_output_range
=
output_range
self
.
_output_range
=
output_range
...
@@ -121,8 +133,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -121,8 +133,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"heads (%d)"
%
(
hidden_size
,
self
.
_num_heads
))
"heads (%d)"
%
(
hidden_size
,
self
.
_num_heads
))
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
common_kwargs
=
dict
(
common_kwargs
=
dict
(
kernel_initializer
=
self
.
_kernel_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
...
@@ -133,6 +143,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -133,6 +143,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
key_dim
=
self
.
_attention_head_size
,
key_dim
=
self
.
_attention_head_size
,
dropout
=
self
.
_attention_dropout_rate
,
dropout
=
self
.
_attention_dropout_rate
,
name
=
"self_attention"
,
name
=
"self_attention"
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
)
**
common_kwargs
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
if
self
.
_use_layer_norm
:
if
self
.
_use_layer_norm
:
...
@@ -144,11 +156,13 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -144,11 +156,13 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis
=-
1
,
axis
=-
1
,
epsilon
=
1e-12
,
epsilon
=
1e-12
,
dtype
=
tf
.
float32
))
dtype
=
tf
.
float32
))
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
EinsumDense
(
"abc,cd->abd"
,
"abc,cd->abd"
,
output_shape
=
(
None
,
self
.
_in
t
er
mediate_size
),
output_shape
=
(
None
,
self
.
_in
n
er
_dim
),
bias_axes
=
"d"
,
bias_axes
=
"d"
,
name
=
"intermediate"
,
name
=
"intermediate"
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
)
**
common_kwargs
)
policy
=
tf
.
keras
.
mixed_precision
.
global_policy
()
policy
=
tf
.
keras
.
mixed_precision
.
global_policy
()
if
policy
.
name
==
"mixed_bfloat16"
:
if
policy
.
name
==
"mixed_bfloat16"
:
...
@@ -156,13 +170,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -156,13 +170,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
# as well, so we use float32.
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
# TODO(b/154538392): Investigate this.
policy
=
tf
.
float32
policy
=
tf
.
float32
self
.
_in
t
er
mediate
_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_in
n
er_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_in
t
er
mediate
_activation
,
dtype
=
policy
)
self
.
_in
n
er_activation
,
dtype
=
policy
)
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_output_dense
=
tf
.
keras
.
layers
.
EinsumDense
(
"abc,cd->abd"
,
"abc,cd->abd"
,
output_shape
=
(
None
,
hidden_size
),
output_shape
=
(
None
,
hidden_size
),
bias_axes
=
"d"
,
bias_axes
=
"d"
,
name
=
"output"
,
name
=
"output"
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
),
bias_initializer
=
tf_utils
.
clone_initializer
(
self
.
_bias_initializer
),
**
common_kwargs
)
**
common_kwargs
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
if
self
.
_use_layer_norm
:
if
self
.
_use_layer_norm
:
...
@@ -185,16 +201,16 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -185,16 +201,16 @@ class ReZeroTransformer(tf.keras.layers.Layer):
trainable
=
True
,
trainable
=
True
,
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
super
(
ReZeroTransformer
,
self
).
build
(
input_shape
)
super
().
build
(
input_shape
)
def
get_config
(
self
):
def
get_config
(
self
):
config
=
{
config
=
{
"num_attention_heads"
:
"num_attention_heads"
:
self
.
_num_heads
,
self
.
_num_heads
,
"in
t
er
mediate_size
"
:
"in
n
er
_dim
"
:
self
.
_in
t
er
mediate_size
,
self
.
_in
n
er
_dim
,
"in
t
er
mediate
_activation"
:
"in
n
er_activation"
:
self
.
_in
t
er
mediate
_activation
,
self
.
_in
n
er_activation
,
"dropout_rate"
:
"dropout_rate"
:
self
.
_dropout_rate
,
self
.
_dropout_rate
,
"attention_dropout_rate"
:
"attention_dropout_rate"
:
...
@@ -220,7 +236,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -220,7 +236,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"bias_constraint"
:
"bias_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_bias_constraint
),
tf
.
keras
.
constraints
.
serialize
(
self
.
_bias_constraint
),
}
}
base_config
=
super
(
ReZeroTransformer
,
self
).
get_config
()
base_config
=
super
().
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
reset_rezero
(
self
):
def
reset_rezero
(
self
):
...
@@ -228,7 +244,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -228,7 +244,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
if
not
self
.
_share_rezero
:
if
not
self
.
_share_rezero
:
self
.
_rezero_a_ffn
.
assign
(
0.
)
self
.
_rezero_a_ffn
.
assign
(
0.
)
def
call
(
self
,
inputs
)
:
def
call
(
self
,
inputs
,
output_range
:
Optional
[
tf
.
Tensor
]
=
None
)
->
tf
.
Tensor
:
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
len
(
inputs
)
==
2
:
if
len
(
inputs
)
==
2
:
input_tensor
,
attention_mask
=
inputs
input_tensor
,
attention_mask
=
inputs
...
@@ -241,10 +257,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -241,10 +257,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
else
:
else
:
input_tensor
,
key_value
,
attention_mask
=
(
inputs
,
None
,
None
)
input_tensor
,
key_value
,
attention_mask
=
(
inputs
,
None
,
None
)
if
self
.
_output_range
:
if
output_range
is
None
:
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
output_range
=
self
.
_output_range
if
output_range
:
target_tensor
=
input_tensor
[:,
0
:
output_range
,
:]
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
0
:
self
.
_
output_range
,
:]
attention_mask
=
attention_mask
[:,
0
:
output_range
,
:]
else
:
else
:
target_tensor
=
input_tensor
target_tensor
=
input_tensor
...
@@ -261,8 +279,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -261,8 +279,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_output
=
tf
.
cast
(
attention_output
,
tf
.
float32
)
attention_output
=
tf
.
cast
(
attention_output
,
tf
.
float32
)
intermediate_output
=
self
.
_intermediate_dense
(
attention_output
)
intermediate_output
=
self
.
_intermediate_dense
(
attention_output
)
intermediate_output
=
self
.
_intermediate_activation_layer
(
intermediate_output
=
self
.
_inner_activation_layer
(
intermediate_output
)
intermediate_output
)
layer_output
=
self
.
_output_dense
(
intermediate_output
)
layer_output
=
self
.
_output_dense
(
intermediate_output
)
layer_output
=
self
.
_output_dropout
(
layer_output
)
layer_output
=
self
.
_output_dropout
(
layer_output
)
# During mixed precision training, attention_output is from layer norm and
# During mixed precision training, attention_output is from layer norm and
...
...
official/nlp/modeling/layers/rezero_transformer_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -128,6 +128,9 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
...
@@ -128,6 +128,9 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
new_output_tensor
=
new_layer
([
input_data
,
mask_data
])
new_output_tensor
=
new_layer
([
input_data
,
mask_data
])
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:])
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:])
output_tensor
=
test_layer
([
input_data
,
mask_data
],
output_range
=
1
)
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
,
atol
=
5e-5
,
rtol
=
0.003
)
def
test_separate_qkv
(
self
):
def
test_separate_qkv
(
self
):
test_layer
=
rezero_transformer
.
ReZeroTransformer
(
test_layer
=
rezero_transformer
.
ReZeroTransformer
(
num_attention_heads
=
2
,
num_attention_heads
=
2
,
...
...
official/nlp/modeling/layers/routing.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for Mixture of Experts (MoE) routing.
For MoE routing, we need to separate a set of tokens to sets of tokens.
Later on, different sets of tokens can potentially go to different experts.
"""
import
tensorflow
as
tf
@tf.keras.utils.register_keras_serializable(package="Text")
class TokenImportanceWithMovingAvg(tf.keras.layers.Layer):
  """Tracks a per-token importance value via an exponential moving average.

  Keeps a non-trainable embedding table with one scalar importance value per
  vocabulary id. `call` looks importance values up; `update_token_importance`
  folds freshly observed importance values into the table with an exponential
  moving average controlled by `moving_average_beta`.
  """

  def __init__(self,
               vocab_size,
               init_importance,
               moving_average_beta=0.995,
               **kwargs):
    """Initializes the layer.

    Args:
      vocab_size: Size of the vocabulary; one importance scalar is kept per id.
      init_importance: Initial importance value for every token id.
      moving_average_beta: Decay factor of the exponential moving average; the
        old value is weighted `beta`, the new observation `1 - beta`.
      **kwargs: Standard `tf.keras.layers.Layer` keyword arguments.
    """
    self._vocab_size = vocab_size
    self._init_importance = init_importance
    self._moving_average_beta = moving_average_beta
    super().__init__(**kwargs)

  def build(self, input_shape):
    # Non-trainable: the table is only updated through
    # `update_token_importance`, never by gradient descent.
    # NOTE: `shape` must be a 1-D shape; the original `(self._vocab_size)` was
    # a bare int (missing trailing comma), so make the tuple explicit.
    self._importance_embedding = self.add_weight(
        name="importance_embed",
        shape=(self._vocab_size,),
        initializer=tf.keras.initializers.Constant(self._init_importance),
        trainable=False)

  def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "init_importance": self._init_importance,
        "moving_average_beta": self._moving_average_beta,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def update_token_importance(self, token_ids, importance):
    """Updates the moving-average importance for the given token ids.

    Args:
      token_ids: Integer tensor of any shape; flattened before the update.
      importance: Tensor of newly observed importance values, same number of
        elements as `token_ids`; flattened before the update.
    """
    token_ids = tf.reshape(token_ids, shape=[-1])
    importance = tf.reshape(importance, shape=[-1])
    beta = self._moving_average_beta
    old_importance = tf.gather(self._importance_embedding, token_ids)
    # EMA update: new = old * beta + observed * (1 - beta). The observation is
    # cast to float32 to match the table's dtype.
    self._importance_embedding.assign(
        tf.tensor_scatter_nd_update(
            self._importance_embedding,
            tf.expand_dims(token_ids, axis=1),
            old_importance * beta +
            tf.cast(importance * (1.0 - beta), dtype=tf.float32)))

  def call(self, inputs):
    """Returns the current importance values for the token ids in `inputs`."""
    return tf.gather(self._importance_embedding, inputs)
@tf.keras.utils.register_keras_serializable(package="Text")
class SelectTopK(tf.keras.layers.Layer):
  """Select top-k + random-k tokens according to importance.

  Given a [batch, length] importance tensor, returns the column indices of the
  chosen tokens and of the remaining tokens. Three modes, depending on which
  constructor arguments are set: top-k only, random-k only, or top-k followed
  by random-k drawn from the leftovers.
  """

  def __init__(self, top_k=None, random_k=None, **kwargs):
    """Initializes the layer.

    Args:
      top_k: Number of highest-importance tokens to select, or None.
      random_k: Number of tokens to select uniformly at random, or None.
      **kwargs: Standard `tf.keras.layers.Layer` keyword arguments.
    """
    self._top_k = top_k
    self._random_k = random_k
    super().__init__(**kwargs)

  def get_config(self):
    config = {
        "top_k": self._top_k,
        "random_k": self._random_k,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    if self._random_k is None:
      # Pure top-k, no randomness: rank all positions by importance and split
      # off the first `top_k` columns.
      ranked = tf.argsort(inputs, direction="DESCENDING")
      selected = tf.slice(ranked, [0, 0], [-1, self._top_k])
      not_selected = tf.slice(ranked, [0, self._top_k], [-1, -1])
    elif self._top_k is None:
      # Pure randomness, no top-k: rank by uniform noise instead of importance,
      # which yields a random permutation of the positions.
      ranked = tf.argsort(
          tf.random.uniform(shape=tf.shape(inputs)), direction="DESCENDING")
      selected = tf.slice(ranked, [0, 0], [-1, self._random_k])
      not_selected = tf.slice(ranked, [0, self._random_k], [-1, -1])
    else:
      # Top-k plus randomness: take the top-k first, then draw `random_k` more
      # uniformly from the remaining positions.
      ranked = tf.argsort(inputs, direction="DESCENDING")
      selected_top_k = tf.slice(ranked, [0, 0], [-1, self._top_k])
      leftovers = tf.slice(ranked, [0, self._top_k], [-1, -1])
      # Shuffle the leftovers per batch row by sorting random keys.
      shuffle_order = tf.argsort(
          tf.random.uniform(shape=tf.shape(leftovers)),
          direction="DESCENDING")
      leftovers = tf.gather(leftovers, shuffle_order, batch_dims=1, axis=1)
      selected_rand = tf.slice(leftovers, [0, 0], [-1, self._random_k])
      not_selected = tf.slice(leftovers, [0, self._random_k], [-1, -1])
      selected = tf.concat([selected_top_k, selected_rand], axis=1)
    # Return the indices of selected and not-selected tokens.
    return selected, not_selected
official/nlp/modeling/layers/routing_test.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for routing."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
routing
class TokenImportanceTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `routing.TokenImportanceWithMovingAvg`."""

  def test_token_importance(self):
    layer = routing.TokenImportanceWithMovingAvg(
        vocab_size=4,
        init_importance=10.0,
        moving_average_beta=0.995)

    # Before any update, every token id reports the initial importance.
    lookup = layer(np.array([[0, 1], [2, 3]]))
    self.assertAllClose(lookup, np.array([[10.0, 10.0], [10.0, 10.0]]))

    # Observe importance 0.0 for ids 0 and 1; their values should decay by
    # one EMA step (10 * 0.995 = 9.95) while ids 2 and 3 stay untouched.
    layer.update_token_importance(
        token_ids=np.array([[0, 1]]),
        importance=np.array([[0.0, 0.0]]))
    lookup = layer(np.array([[0, 1], [2, 3]]))
    self.assertAllClose(lookup, np.array([[9.95, 9.95], [10.0, 10.0]]))
class TopKSelectionTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `routing.SelectTopK`."""

  def test_top_k_selection(self):
    # Deterministic: the two largest entries per row are selected, in
    # descending-importance order.
    selector = routing.SelectTopK(top_k=2)
    selected, _ = selector(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
    self.assertAllClose(selected, np.array([[3, 2], [0, 1]]))

  def test_random_k_selection(self):
    # Random selection: only the output shape can be checked.
    selector = routing.SelectTopK(random_k=2)
    selected, _ = selector(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
    self.assertAllClose(selected.shape, (2, 2))

  def test_top_k_random_k(self):
    # Mixed mode: one deterministic pick plus one random pick per row.
    selector = routing.SelectTopK(top_k=1, random_k=1)
    selected, _ = selector(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
    self.assertAllClose(selected.shape, (2, 2))
if __name__ == "__main__":
  tf.test.main()
official/nlp/modeling/layers/self_attention_mask.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -13,10 +13,38 @@
...
@@ -13,10 +13,38 @@
# limitations under the License.
# limitations under the License.
"""Keras layer that creates a self-attention mask."""
"""Keras layer that creates a self-attention mask."""
from
typing
import
Optional
import
tensorflow
as
tf
import
tensorflow
as
tf
def get_mask(inputs: tf.Tensor,
             to_mask: tf.Tensor,
             dtype: Optional[tf.DType] = None) -> tf.Tensor:
  """Builds a 3D self-attention mask by broadcasting a 2D padding mask.

  Args:
    inputs: 2D or 3D Tensor of shape `[batch_size, from_seq_length, ...]`;
      only its dynamic shape and dtype are consumed here.
    to_mask: int32 Tensor of shape `[batch_size, to_seq_length]`.
    dtype: Optional output dtype; defaults to `inputs.dtype`.

  Returns:
    float Tensor of shape `[batch_size, from_seq_length, to_seq_length]`.
  """
  input_shape = tf.shape(inputs)
  batch_size = input_shape[0]
  from_seq_length = input_shape[1]
  if dtype is None:
    dtype = inputs.dtype
  to_seq_length = tf.shape(to_mask)[1]
  # [batch, to_seq] -> [batch, 1, to_seq], then replicate along the from-axis.
  expanded = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=dtype)
  return tf.broadcast_to(
      expanded, [batch_size, from_seq_length, to_seq_length])
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
SelfAttentionMask
(
tf
.
keras
.
layers
.
Layer
):
class
SelfAttentionMask
(
tf
.
keras
.
layers
.
Layer
):
"""Create 3D attention mask from a 2D tensor mask.
"""Create 3D attention mask from a 2D tensor mask.
...
@@ -33,26 +61,4 @@ class SelfAttentionMask(tf.keras.layers.Layer):
...
@@ -33,26 +61,4 @@ class SelfAttentionMask(tf.keras.layers.Layer):
if
isinstance
(
inputs
,
list
)
and
to_mask
is
None
:
if
isinstance
(
inputs
,
list
)
and
to_mask
is
None
:
to_mask
=
inputs
[
1
]
to_mask
=
inputs
[
1
]
inputs
=
inputs
[
0
]
inputs
=
inputs
[
0
]
from_shape
=
tf
.
shape
(
inputs
)
return
get_mask
(
inputs
,
to_mask
)
batch_size
=
from_shape
[
0
]
from_seq_length
=
from_shape
[
1
]
to_shape
=
tf
.
shape
(
to_mask
)
to_seq_length
=
to_shape
[
1
]
to_mask
=
tf
.
cast
(
tf
.
reshape
(
to_mask
,
[
batch_size
,
1
,
to_seq_length
]),
dtype
=
inputs
.
dtype
)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones
=
tf
.
ones
(
shape
=
[
batch_size
,
from_seq_length
,
1
],
dtype
=
inputs
.
dtype
)
# Here we broadcast along two dimensions to create the mask.
mask
=
broadcast_ones
*
to_mask
return
mask
official/nlp/modeling/layers/spectral_normalization.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -74,21 +74,20 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
...
@@ -74,21 +74,20 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
if
not
isinstance
(
layer
,
tf
.
keras
.
layers
.
Layer
):
if
not
isinstance
(
layer
,
tf
.
keras
.
layers
.
Layer
):
raise
ValueError
(
'`layer` must be a `tf.keras.layer.Layer`. '
raise
ValueError
(
'`layer` must be a `tf.keras.layer.Layer`. '
'Observed `{}`'
.
format
(
layer
))
'Observed `{}`'
.
format
(
layer
))
super
(
SpectralNormalization
,
self
).
__init__
(
super
().
__init__
(
layer
,
name
=
wrapper_name
,
**
kwargs
)
layer
,
name
=
wrapper_name
,
**
kwargs
)
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
super
(
SpectralNormalization
,
self
).
build
(
input_shape
)
super
().
build
(
input_shape
)
self
.
layer
.
kernel
.
_aggregation
=
self
.
aggregation
# pylint: disable=protected-access
self
.
layer
.
kernel
.
_aggregation
=
self
.
aggregation
# pylint: disable=protected-access
self
.
_dtype
=
self
.
layer
.
kernel
.
dtype
self
.
_dtype
=
self
.
layer
.
kernel
.
dtype
self
.
w
=
self
.
layer
.
kernel
self
.
w
=
self
.
layer
.
kernel
self
.
w_shape
=
self
.
w
.
shape
.
as_list
()
self
.
w_shape
=
self
.
w
.
shape
.
as_list
()
self
.
uv_initializer
=
tf
.
initializers
.
random_normal
()
self
.
v
=
self
.
add_weight
(
self
.
v
=
self
.
add_weight
(
shape
=
(
1
,
np
.
prod
(
self
.
w_shape
[:
-
1
])),
shape
=
(
1
,
np
.
prod
(
self
.
w_shape
[:
-
1
])),
initializer
=
self
.
uv_
initializer
,
initializer
=
tf
.
initializer
s
.
random_normal
()
,
trainable
=
False
,
trainable
=
False
,
name
=
'v'
,
name
=
'v'
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
...
@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
...
@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
self
.
u
=
self
.
add_weight
(
self
.
u
=
self
.
add_weight
(
shape
=
(
1
,
self
.
w_shape
[
-
1
]),
shape
=
(
1
,
self
.
w_shape
[
-
1
]),
initializer
=
self
.
uv_
initializer
,
initializer
=
tf
.
initializer
s
.
random_normal
()
,
trainable
=
False
,
trainable
=
False
,
name
=
'u'
,
name
=
'u'
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
...
@@ -194,10 +193,11 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
...
@@ -194,10 +193,11 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
raise
ValueError
(
raise
ValueError
(
'layer must be a `tf.keras.layer.Conv2D` instance. You passed: {input}'
'layer must be a `tf.keras.layer.Conv2D` instance. You passed: {input}'
.
format
(
input
=
layer
))
.
format
(
input
=
layer
))
super
(
SpectralNormalizationConv2D
,
self
).
__init__
(
layer
,
**
kwargs
)
super
().
__init__
(
layer
,
**
kwargs
)
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
self
.
layer
.
build
(
input_shape
)
if
not
self
.
layer
.
built
:
self
.
layer
.
build
(
input_shape
)
self
.
layer
.
kernel
.
_aggregation
=
self
.
aggregation
# pylint: disable=protected-access
self
.
layer
.
kernel
.
_aggregation
=
self
.
aggregation
# pylint: disable=protected-access
self
.
_dtype
=
self
.
layer
.
kernel
.
dtype
self
.
_dtype
=
self
.
layer
.
kernel
.
dtype
...
@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
...
@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self
.
in_shape
=
(
uv_dim
,
in_height
,
in_width
,
in_channel
)
self
.
in_shape
=
(
uv_dim
,
in_height
,
in_width
,
in_channel
)
self
.
out_shape
=
(
uv_dim
,
out_height
,
out_width
,
out_channel
)
self
.
out_shape
=
(
uv_dim
,
out_height
,
out_width
,
out_channel
)
self
.
uv_initializer
=
tf
.
initializers
.
random_normal
()
self
.
v
=
self
.
add_weight
(
self
.
v
=
self
.
add_weight
(
shape
=
self
.
in_shape
,
shape
=
self
.
in_shape
,
initializer
=
self
.
uv_
initializer
,
initializer
=
tf
.
initializer
s
.
random_normal
()
,
trainable
=
False
,
trainable
=
False
,
name
=
'v'
,
name
=
'v'
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
...
@@ -233,13 +232,13 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
...
@@ -233,13 +232,13 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self
.
u
=
self
.
add_weight
(
self
.
u
=
self
.
add_weight
(
shape
=
self
.
out_shape
,
shape
=
self
.
out_shape
,
initializer
=
self
.
uv_
initializer
,
initializer
=
tf
.
initializer
s
.
random_normal
()
,
trainable
=
False
,
trainable
=
False
,
name
=
'u'
,
name
=
'u'
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
aggregation
=
self
.
aggregation
)
aggregation
=
self
.
aggregation
)
super
(
SpectralNormalizationConv2D
,
self
).
build
()
super
().
build
()
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
u_update_op
,
v_update_op
,
w_update_op
=
self
.
update_weights
()
u_update_op
,
v_update_op
,
w_update_op
=
self
.
update_weights
()
...
...
official/nlp/modeling/layers/spectral_normalization_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
spectral_norm_computed
=
_compute_spectral_norm
(
normalized_kernel
)
spectral_norm_computed
=
_compute_spectral_norm
(
normalized_kernel
)
spectral_norm_expected
=
self
.
norm_multiplier
spectral_norm_expected
=
self
.
norm_multiplier
self
.
assertAllClose
(
self
.
assertAllClose
(
spectral_norm_computed
,
spectral_norm_expected
,
atol
=
5
e-
2
)
spectral_norm_computed
,
spectral_norm_expected
,
atol
=
1
e-
1
)
# Test that the normalized layer is K-Lipschitz. In particular, if the layer
# Test that the normalized layer is K-Lipschitz. In particular, if the layer
# is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
# is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
...
...
official/nlp/modeling/layers/talking_heads_attention.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -20,6 +20,8 @@ import string
...
@@ -20,6 +20,8 @@ import string
import
gin
import
gin
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
_CHR_IDX
=
string
.
ascii_lowercase
_CHR_IDX
=
string
.
ascii_lowercase
...
@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self
.
_pre_softmax_weight
=
self
.
add_weight
(
self
.
_pre_softmax_weight
=
self
.
add_weight
(
"pre_softmax_weight"
,
"pre_softmax_weight"
,
shape
=
(
self
.
_num_heads
,
self
.
_num_heads
),
shape
=
(
self
.
_num_heads
,
self
.
_num_heads
),
initializer
=
self
.
_kernel_initializer
,
initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
)
,
regularizer
=
self
.
_kernel_regularizer
,
regularizer
=
self
.
_kernel_regularizer
,
constraint
=
self
.
_kernel_constraint
,
constraint
=
self
.
_kernel_constraint
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
...
@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self
.
_post_softmax_weight
=
self
.
add_weight
(
self
.
_post_softmax_weight
=
self
.
add_weight
(
"post_softmax_weight"
,
"post_softmax_weight"
,
shape
=
(
self
.
_num_heads
,
self
.
_num_heads
),
shape
=
(
self
.
_num_heads
,
self
.
_num_heads
),
initializer
=
self
.
_kernel_initializer
,
initializer
=
tf_utils
.
clone_initializer
(
self
.
_kernel_initializer
)
,
regularizer
=
self
.
_kernel_regularizer
,
regularizer
=
self
.
_kernel_regularizer
,
constraint
=
self
.
_kernel_constraint
,
constraint
=
self
.
_kernel_constraint
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
...
...
official/nlp/modeling/layers/talking_heads_attention_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
official/nlp/modeling/layers/text_layers.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -14,14 +14,16 @@
...
@@ -14,14 +14,16 @@
"""Keras Layers for BERT-specific preprocessing."""
"""Keras Layers for BERT-specific preprocessing."""
# pylint: disable=g-import-not-at-top
# pylint: disable=g-import-not-at-top
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Text
,
Union
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
try
:
try
:
# pytype: disable=import-error
import
tensorflow_text
as
text
import
tensorflow_text
as
text
from
tensorflow_text.python.ops
import
bert_tokenizer
from
tensorflow_text.python.ops
import
bert_tokenizer
# pytype: enable=import-error
except
ImportError
:
except
ImportError
:
text
=
None
text
=
None
bert_tokenizer
=
None
bert_tokenizer
=
None
...
@@ -57,7 +59,7 @@ def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
...
@@ -57,7 +59,7 @@ def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
class
BertTokenizer
(
tf
.
keras
.
layers
.
Layer
):
class
BertTokenizer
(
tf
.
keras
.
layers
.
Layer
):
"""Wraps BertTokenizer with pre-defined vocab as a Keras Layer.
"""Wraps
TF.Text's
BertTokenizer with pre-defined vocab as a Keras Layer.
Attributes:
Attributes:
tokenize_with_offsets: If true, calls
tokenize_with_offsets: If true, calls
...
@@ -71,8 +73,9 @@ class BertTokenizer(tf.keras.layers.Layer):
...
@@ -71,8 +73,9 @@ class BertTokenizer(tf.keras.layers.Layer):
def
__init__
(
self
,
*
,
def
__init__
(
self
,
*
,
vocab_file
:
str
,
vocab_file
:
str
,
lower_case
:
bool
,
lower_case
:
Optional
[
bool
]
=
None
,
tokenize_with_offsets
:
bool
=
False
,
tokenize_with_offsets
:
bool
=
False
,
tokenizer_kwargs
:
Optional
[
Mapping
[
Text
,
Any
]]
=
None
,
**
kwargs
):
**
kwargs
):
"""Initialize a `BertTokenizer` layer.
"""Initialize a `BertTokenizer` layer.
...
@@ -81,15 +84,18 @@ class BertTokenizer(tf.keras.layers.Layer):
...
@@ -81,15 +84,18 @@ class BertTokenizer(tf.keras.layers.Layer):
This is a text file with newline-separated wordpiece tokens.
This is a text file with newline-separated wordpiece tokens.
This layer initializes a lookup table from it that gets used with
This layer initializes a lookup table from it that gets used with
`text.BertTokenizer`.
`text.BertTokenizer`.
lower_case:
A Python
boolean forwarded to `text.BertTokenizer`.
lower_case:
Optional
boolean forwarded to `text.BertTokenizer`.
If true, input text is converted to lower case (where applicable)
If true, input text is converted to lower case (where applicable)
before tokenization. This must be set to match the way in which
before tokenization. This must be set to match the way in which
the `vocab_file` was created.
the `vocab_file` was created. If passed, this overrides whatever value
may have been passed in `tokenizer_kwargs`.
tokenize_with_offsets: A Python boolean. If true, this layer calls
tokenize_with_offsets: A Python boolean. If true, this layer calls
`text.BertTokenizer.tokenize_with_offsets()` instead of plain
`text.BertTokenizer.tokenize_with_offsets()` instead of plain
`text.BertTokenizer.tokenize()` and outputs a triple of
`text.BertTokenizer.tokenize()` and outputs a triple of
`(tokens, start_offsets, limit_offsets)`
`(tokens, start_offsets, limit_offsets)`
instead of just tokens.
instead of just tokens.
tokenizer_kwargs: Optional mapping with keyword arguments to forward to
`text.BertTokenizer`'s constructor.
**kwargs: Standard arguments to `Layer()`.
**kwargs: Standard arguments to `Layer()`.
Raises:
Raises:
...
@@ -111,8 +117,11 @@ class BertTokenizer(tf.keras.layers.Layer):
...
@@ -111,8 +117,11 @@ class BertTokenizer(tf.keras.layers.Layer):
self
.
_special_tokens_dict
=
self
.
_create_special_tokens_dict
(
self
.
_special_tokens_dict
=
self
.
_create_special_tokens_dict
(
self
.
_vocab_table
,
vocab_file
)
self
.
_vocab_table
,
vocab_file
)
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
_bert_tokenizer
=
text
.
BertTokenizer
(
tokenizer_kwargs
=
dict
(
tokenizer_kwargs
or
{})
self
.
_vocab_table
,
lower_case
=
lower_case
)
if
lower_case
is
not
None
:
tokenizer_kwargs
[
"lower_case"
]
=
lower_case
self
.
_bert_tokenizer
=
text
.
BertTokenizer
(
self
.
_vocab_table
,
**
tokenizer_kwargs
)
@
property
@
property
def
vocab_size
(
self
):
def
vocab_size
(
self
):
...
...
Prev
1
…
17
18
19
20
21
22
23
24
25
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment