ModelZoo / ResNet50_tensorflow · Commit 76f760c4

Commit 76f760c4, authored Jul 30, 2020 by A. Unique TensorFlower

    Merge pull request #9005 from xinliupitt:master

    PiperOrigin-RevId: 324102453

Parents: 2e785497, 673231ff

Showing 3 changed files with 201 additions and 20 deletions (+201, -20):

    official/nlp/modeling/layers/attention.py          +2    -2
    official/nlp/modeling/layers/transformer.py        +123  -14
    official/nlp/modeling/layers/transformer_test.py   +76   -4
official/nlp/modeling/layers/attention.py

@@ -521,11 +521,11 @@ class CachedAttention(MultiHeadAttention):
     if cache:
       key, value = self._update_cache(key, value, cache, decode_loop_step)
 
+    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
+
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
     attention_scores = tf.einsum(self._dot_product_equation, key, query)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))
 
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, F, T]
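The attention.py change moves the 1/sqrt(key_size) scaling from the raw attention scores onto the query before the einsum. Because the einsum is linear in the query, both orderings produce the same scores; scaling the query avoids multiplying the larger [B, N, F, T] score tensor. A minimal sketch of the equivalence (the shapes and the "aecd,abcd->acbe" equation string are assumptions for illustration, not copied from the layer):

# Sketch only: demonstrates that pre-scaling the query is equivalent to
# post-scaling the scores, since tf.einsum is linear in `query`.
import math
import tensorflow as tf

key_size = 8
query = tf.random.normal([2, 4, 2, key_size])  # assumed [batch, from_seq, heads, key_size]
key = tf.random.normal([2, 6, 2, key_size])    # assumed [batch, to_seq, heads, key_size]
dot_product_equation = "aecd,abcd->acbe"       # illustrative; yields scores shaped [B, N, F, T]

# Old ordering: einsum first, then scale the raw scores.
scores_old = tf.einsum(dot_product_equation, key, query)
scores_old = tf.multiply(scores_old, 1.0 / math.sqrt(float(key_size)))

# New ordering: scale the query first, then einsum.
scaled_query = tf.multiply(query, 1.0 / math.sqrt(float(key_size)))
scores_new = tf.einsum(dot_product_equation, key, scaled_query)

tf.debugging.assert_near(scores_old, scores_new)  # identical up to float rounding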
official/nlp/modeling/layers/transformer.py

@@ -49,6 +49,12 @@ class Transformer(tf.keras.layers.Layer):
     activity_regularizer: Regularizer for dense layer activity.
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
+    use_bias: Whether to enable use_bias in attention layer. If set False,
+      use_bias in attention layer is disabled.
+    norm_first: Whether to normalize inputs to attention and intermediate dense
+      layers. If set False, output of attention and intermediate dense layers is
+      normalized.
+    norm_epsilon: Epsilon value to initialize normalization layers.
   """
 
   def __init__(self,
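The three new constructor arguments documented in this hunk can be exercised as in the tests added later in this commit. A minimal usage sketch (the `official.nlp.modeling.layers.transformer` import path follows the file locations in this diff; values mirror `test_use_bias_norm_first`):

# Sketch of constructing the encoder block with the new arguments.
import tensorflow as tf
from official.nlp.modeling.layers import transformer

encoder_block = transformer.Transformer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation='relu',
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    use_bias=False,       # drop bias terms in the attention projections
    norm_first=True,      # pre-norm: normalize sub-layer inputs, not residual sums
    norm_epsilon=1e-6)    # epsilon passed to the block's LayerNormalization layers

data = tf.zeros([2, 4, 16], dtype=tf.float32)  # [batch, seq_len, hidden]
mask = tf.zeros([2, 4, 4], dtype=tf.float32)   # [batch, seq_len, seq_len]
output = encoder_block([data, mask])           # -> shape (2, 4, 16)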
@@ -65,6 +71,9 @@ class Transformer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
                **kwargs):
     super(Transformer, self).__init__(**kwargs)

@@ -81,6 +90,9 @@ class Transformer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
 
   def build(self, input_shape):
     input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -117,6 +129,7 @@ class Transformer(tf.keras.layers.Layer):
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
         dropout=self._attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

@@ -126,7 +139,7 @@ class Transformer(tf.keras.layers.Layer):
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm",
             axis=-1,
-            epsilon=1e-12,
+            epsilon=self._norm_epsilon,
             dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",

@@ -151,7 +164,10 @@ class Transformer(tf.keras.layers.Layer):
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+        name="output_layer_norm",
+        axis=-1,
+        epsilon=self._norm_epsilon,
+        dtype=tf.float32)
 
     super(Transformer, self).build(input_shape)
@@ -182,7 +198,13 @@ class Transformer(tf.keras.layers.Layer):
         "kernel_constraint":
             tf.keras.constraints.serialize(self._kernel_constraint),
         "bias_constraint":
-            tf.keras.constraints.serialize(self._bias_constraint)
+            tf.keras.constraints.serialize(self._bias_constraint),
+        "use_bias":
+            self._use_bias,
+        "norm_first":
+            self._norm_first,
+        "norm_epsilon":
+            self._norm_epsilon
     }
     base_config = super(Transformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
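With `use_bias`, `norm_first`, and `norm_epsilon` included in the config dict, the encoder block now round-trips through `get_config`/`from_config`. A minimal sketch of that check, mirroring the `test_get_config` tests added in this commit:

# Sketch: the new arguments survive a Keras config round trip.
from official.nlp.modeling.layers import transformer

block = transformer.Transformer(
    num_attention_heads=2, intermediate_size=32,
    intermediate_activation='relu', dropout_rate=0.1,
    attention_dropout_rate=0.1, use_bias=False,
    norm_first=True, norm_epsilon=1e-6)
rebuilt = transformer.Transformer.from_config(block.get_config())
assert block.get_config() == rebuilt.get_config()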
@@ -197,13 +219,22 @@ class Transformer(tf.keras.layers.Layer):
       target_tensor = input_tensor[:, 0:self._output_range, :]
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
+      if self._norm_first:
+        source_tensor = input_tensor
+        input_tensor = self._attention_layer_norm(input_tensor)
       target_tensor = input_tensor
     attention_output = self._attention_layer(
         query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
-    attention_output = self._attention_layer_norm(target_tensor +
-                                                  attention_output)
+    if self._norm_first:
+      attention_output = source_tensor + attention_output
+    else:
+      attention_output = self._attention_layer_norm(target_tensor +
+                                                    attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output)
     intermediate_output = self._intermediate_dense(attention_output)
     intermediate_output = self._intermediate_activation_layer(
         intermediate_output)

@@ -213,7 +244,10 @@ class Transformer(tf.keras.layers.Layer):
     # is always fp32 for now. Cast layer_output to fp32 for the subsequent
     # add.
     layer_output = tf.cast(layer_output, tf.float32)
-    layer_output = self._output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self._output_layer_norm(layer_output + attention_output)
 
     return layer_output
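The branches added to `call` switch the block between the existing post-norm residual layout and a pre-norm layout when `norm_first=True`: with pre-norm, layer normalization is applied to the sub-layer input and the residual sum is left unnormalized. A condensed, illustrative sketch of the two orderings for the self-attention sub-block (the `attn`, `norm`, and `dropout` callables are stand-ins, not the layer's members):

# Sketch only: the two residual orderings implemented by the branches above.
def post_norm_self_attention(x, attn, norm, dropout):
  # norm_first=False (original behaviour): add the residual, then normalize.
  return norm(x + dropout(attn(query=x, value=x)))

def pre_norm_self_attention(x, attn, norm, dropout):
  # norm_first=True: normalize the input, add the un-normalized residual.
  normed = norm(x)
  return x + dropout(attn(query=normed, value=normed))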
@@ -251,6 +285,12 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     activity_regularizer: Regularizer for dense layer activity.
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
+    use_bias: Whether to enable use_bias in attention layer. If set False,
+      use_bias in attention layer is disabled.
+    norm_first: Whether to normalize inputs to attention and intermediate dense
+      layers. If set False, output of attention and intermediate dense layers is
+      normalized.
+    norm_epsilon: Epsilon value to initialize normalization layers.
   """
 
   def __init__(self,

@@ -267,6 +307,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
     self.num_attention_heads = num_attention_heads

@@ -283,6 +326,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -312,6 +358,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(

@@ -324,13 +371,16 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+            name="self_attention_layer_norm",
+            axis=-1,
+            epsilon=self._norm_epsilon))
     # Encoder-decoder attention.
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
+        use_bias=self._use_bias,
         name="attention/encdec",
         **common_kwargs)

@@ -338,7 +388,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.encdec_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
+            name="attention/encdec_output_layer_norm",
+            axis=-1,
+            epsilon=self._norm_epsilon))
     # Feed-forward projection.
     self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
@@ -357,9 +409,47 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12)
+        name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
     super(TransformerDecoderLayer, self).build(input_shape)
 
+  def get_config(self):
+    config = {
+        "num_attention_heads":
+            self.num_attention_heads,
+        "intermediate_size":
+            self.intermediate_size,
+        "intermediate_activation":
+            self.intermediate_activation,
+        "dropout_rate":
+            self.dropout_rate,
+        "attention_dropout_rate":
+            self.attention_dropout_rate,
+        "multi_channel_cross_attention":
+            self.multi_channel_cross_attention,
+        "kernel_initializer":
+            tf.keras.initializers.serialize(self._kernel_initializer),
+        "bias_initializer":
+            tf.keras.initializers.serialize(self._bias_initializer),
+        "kernel_regularizer":
+            tf.keras.regularizers.serialize(self._kernel_regularizer),
+        "bias_regularizer":
+            tf.keras.regularizers.serialize(self._bias_regularizer),
+        "activity_regularizer":
+            tf.keras.regularizers.serialize(self._activity_regularizer),
+        "kernel_constraint":
+            tf.keras.constraints.serialize(self._kernel_constraint),
+        "bias_constraint":
+            tf.keras.constraints.serialize(self._bias_constraint),
+        "use_bias":
+            self._use_bias,
+        "norm_first":
+            self._norm_first,
+        "norm_epsilon":
+            self._norm_epsilon
+    }
+    base_config = super(TransformerDecoderLayer, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
   def common_layers_with_encoder(self):
     """Gets layer objects that can make a Transformer encoder block."""
     return [
@@ -378,6 +468,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
+    source_tensor = input_tensor
+    if self._norm_first:
+      input_tensor = self.self_attention_layer_norm(input_tensor)
     self_attention_output, cache = self.self_attention(
         query=input_tensor,
         value=input_tensor,

@@ -385,8 +478,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
-    self_attention_output = self.self_attention_layer_norm(
-        input_tensor + self_attention_output)
+    if self._norm_first:
+      self_attention_output = source_tensor + self_attention_output
+    else:
+      self_attention_output = self.self_attention_layer_norm(
+          input_tensor + self_attention_output)
+    if self._norm_first:
+      source_self_attention_output = self_attention_output
+      self_attention_output = self.encdec_attention_layer_norm(
+          self_attention_output)
     cross_attn_inputs = dict(
         query=self_attention_output,
         value=memory,

@@ -396,13 +496,22 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
       cross_attn_inputs["context_attention_weights"] = inputs[-1]
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
-    attention_output = self.encdec_attention_layer_norm(self_attention_output +
-                                                        attention_output)
+    if self._norm_first:
+      attention_output = source_self_attention_output + attention_output
+    else:
+      attention_output = self.encdec_attention_layer_norm(
+          self_attention_output + attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self.output_layer_norm(attention_output)
     intermediate_output = self.intermediate_dense(attention_output)
     intermediate_output = self.intermediate_activation_layer(
         intermediate_output)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
-    layer_output = self.output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self.output_layer_norm(layer_output + attention_output)
     return layer_output, cache
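The decoder-side changes mirror the encoder block: the same three arguments are accepted, and `call` still takes the four-element input list `[input_tensor, memory, attention_mask, self_attention_mask]` and returns the layer output together with the attention cache. A minimal usage sketch, mirroring the decoder test added in `transformer_test.py` below:

# Sketch of driving TransformerDecoderLayer with the new arguments.
import tensorflow as tf
from official.nlp.modeling.layers import transformer

decoder_block = transformer.TransformerDecoderLayer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation='relu',
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6)

targets = tf.zeros([2, 4, 16], dtype=tf.float32)  # decoder-side input
memory = tf.zeros([2, 4, 16], dtype=tf.float32)   # encoder output
mask = tf.zeros([2, 4, 4], dtype=tf.float32)
output, _ = decoder_block([targets, memory, mask, mask])  # -> shape (2, 4, 16)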
official/nlp/modeling/layers/transformer_test.py

@@ -152,10 +152,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = new_layer([input_data, mask_data])
     new_layer.set_weights(test_layer.get_weights())
     new_output_tensor = new_layer([input_data, mask_data])
-    self.assertAllClose(new_output_tensor,
-                        output_tensor[:, 0:1, :],
-                        atol=5e-5,
-                        rtol=0.003)
+    self.assertAllClose(
+        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
 
   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
@@ -218,6 +216,45 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     self.assertAllEqual([1, input_length, width], output_data.shape)
 
 
+@keras_parameterized.run_all_keras_modes
+class TransformerArgumentTest(keras_parameterized.TestCase):
+
+  def test_use_bias_norm_first(self):
+    num_attention_heads = 2
+    hidden_size = 16
+    encoder_block = transformer.Transformer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6)
+    # Forward path.
+    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
+    inputs = [dummy_tensor, dummy_mask]
+    output = encoder_block(inputs)
+    self.assertEqual(output.shape, (2, 4, hidden_size))
+
+  def test_get_config(self):
+    num_attention_heads = 2
+    encoder_block = transformer.Transformer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6)
+    encoder_block_config = encoder_block.get_config()
+    new_encoder_block = transformer.Transformer.from_config(
+        encoder_block_config)
+    self.assertEqual(encoder_block_config, new_encoder_block.get_config())
+
+
 def _create_cache(batch_size, init_decode_length, num_heads, head_size):
   return {
       'key':
@@ -251,6 +288,41 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
...
@@ -251,6 +288,41 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
self
.
assertEqual
(
output
.
shape
,
(
2
,
4
,
hidden_size
))
self
.
assertEqual
(
output
.
shape
,
(
2
,
4
,
hidden_size
))
self
.
assertEqual
(
cache
[
'value'
].
shape
,
(
2
,
4
,
2
,
8
))
self
.
assertEqual
(
cache
[
'value'
].
shape
,
(
2
,
4
,
2
,
8
))
def
test_use_bias_norm_first
(
self
):
num_attention_heads
=
2
hidden_size
=
16
decoder_block
=
transformer
.
TransformerDecoderLayer
(
num_attention_heads
=
num_attention_heads
,
intermediate_size
=
32
,
intermediate_activation
=
'relu'
,
dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
,
use_bias
=
False
,
norm_first
=
True
,
norm_epsilon
=
1e-6
)
# Forward path.
dummy_tensor
=
tf
.
zeros
([
2
,
4
,
16
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
4
],
dtype
=
tf
.
float32
)
inputs
=
[
dummy_tensor
,
dummy_tensor
,
dummy_mask
,
dummy_mask
]
output
,
_
=
decoder_block
(
inputs
)
self
.
assertEqual
(
output
.
shape
,
(
2
,
4
,
hidden_size
))
def
test_get_config
(
self
):
num_attention_heads
=
2
decoder_block
=
transformer
.
TransformerDecoderLayer
(
num_attention_heads
=
num_attention_heads
,
intermediate_size
=
32
,
intermediate_activation
=
'relu'
,
dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
,
use_bias
=
False
,
norm_first
=
True
,
norm_epsilon
=
1e-6
)
decoder_block_config
=
decoder_block
.
get_config
()
new_decoder_block
=
transformer
.
TransformerDecoderLayer
.
from_config
(
decoder_block_config
)
self
.
assertEqual
(
decoder_block_config
,
new_decoder_block
.
get_config
())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()