ModelZoo / ResNet50_tensorflow

Commit 23804bc5, authored Jul 29, 2020 by xinliupitt
Parent: bda18166

    transformer, attention layers
Showing 2 changed files with 64 additions and 15 deletions:

    official/nlp/modeling/layers/attention.py      +1   -2
    official/nlp/modeling/layers/transformer.py    +63  -13
official/nlp/modeling/layers/attention.py

@@ -523,9 +523,8 @@ class CachedAttention(MultiHeadAttention):
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
+    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
     attention_scores = tf.einsum(self._dot_product_equation, key, query)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, F, T]
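The attention.py hunk moves the 1/sqrt(key_size) scaling from the raw attention scores onto the query tensor, before the einsum. The two orderings are numerically equivalent up to floating-point rounding, as the standalone sketch below illustrates; the einsum equation and shapes here are chosen for illustration and are not the layer's actual `_dot_product_equation`.

    # Standalone illustration (not part of the commit): scaling the query by
    # 1/sqrt(key_size) before the dot product gives the same attention scores
    # as scaling the raw score matrix afterwards, because the product is
    # linear in the query.
    import math
    import numpy as np

    batch, heads, from_seq, to_seq, key_size = 2, 4, 8, 8, 16
    rng = np.random.default_rng(0)
    query = rng.standard_normal((batch, heads, from_seq, key_size))
    key = rng.standard_normal((batch, heads, to_seq, key_size))

    scale = 1.0 / math.sqrt(float(key_size))

    # Ordering kept by the commit: scale the query, then take the dot product.
    scores_query_scaled = np.einsum("bhfk,bhtk->bhft", query * scale, key)

    # Ordering removed by the commit: dot product first, then scale the scores.
    scores_scaled_after = np.einsum("bhfk,bhtk->bhft", query, key) * scale

    assert np.allclose(scores_query_scaled, scores_scaled_after)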
official/nlp/modeling/layers/transformer.py

@@ -65,6 +65,9 @@ class Transformer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
                **kwargs):
     super(Transformer, self).__init__(**kwargs)
@@ -81,6 +84,9 @@ class Transformer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon

   def build(self, input_shape):
     input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -117,6 +123,7 @@ class Transformer(tf.keras.layers.Layer):
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
         dropout=self._attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     # pylint: disable=protected-access
@@ -132,7 +139,7 @@ class Transformer(tf.keras.layers.Layer):
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm",
             axis=-1,
-            epsilon=1e-12,
+            epsilon=self._norm_epsilon,
             dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
@@ -157,7 +164,8 @@ class Transformer(tf.keras.layers.Layer):
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+        name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon,
+        dtype=tf.float32)
     super(Transformer, self).build(input_shape)
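The two hunks above replace the hard-coded epsilon=1e-12 with the new norm_epsilon constructor argument while keeping the layer norms in float32 for numeric stability. As a reminder of what that epsilon controls, here is a minimal framework-free layer-norm sketch; it is illustrative only and is not the Keras implementation.

    # Illustrative only: epsilon keeps the normalization denominator away from
    # zero; larger values (e.g. 1e-6) are commonly chosen when activations are
    # computed in float16/bfloat16.
    import numpy as np

    def layer_norm(x, gamma, beta, epsilon=1e-12):
      mean = x.mean(axis=-1, keepdims=True)
      var = x.var(axis=-1, keepdims=True)
      return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

    x = np.random.default_rng(1).standard_normal((2, 8)).astype(np.float32)
    y = layer_norm(x, gamma=np.ones(8), beta=np.zeros(8), epsilon=1e-6)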
@@ -203,13 +211,22 @@ class Transformer(tf.keras.layers.Layer):
       target_tensor = input_tensor[:, 0:self._output_range, :]
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
+      if self._norm_first:
+        source_tensor = input_tensor
+        input_tensor = self._attention_layer_norm(input_tensor)
       target_tensor = input_tensor
     attention_output = self._attention_layer(
         query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
-    attention_output = self._attention_layer_norm(target_tensor +
-                                                  attention_output)
+    if self._norm_first:
+      attention_output = source_tensor + attention_output
+    else:
+      attention_output = self._attention_layer_norm(target_tensor +
+                                                    attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output)
     intermediate_output = self._intermediate_dense(attention_output)
     intermediate_output = self._intermediate_activation_layer(intermediate_output)
@@ -219,7 +236,10 @@ class Transformer(tf.keras.layers.Layer):
     # is always fp32 for now. Cast layer_output to fp32 for the subsequent
     # add.
     layer_output = tf.cast(layer_output, tf.float32)
-    layer_output = self._output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self._output_layer_norm(layer_output + attention_output)
     return layer_output
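The `_norm_first` branches added to `Transformer.call` switch the encoder block between the original post-layer-norm ordering (residual add, then normalize) and a pre-layer-norm ordering (normalize the sub-layer input and leave the residual path un-normalized). A condensed sketch of the two orderings follows, with illustrative function names rather than the layer's actual attributes.

    # Illustrative contrast of the two residual orderings toggled by norm_first.
    def post_ln_block(x, sublayer, layer_norm):
      # norm_first=False (original behaviour): add the residual, then normalize.
      return layer_norm(x + sublayer(x))

    def pre_ln_block(x, sublayer, layer_norm):
      # norm_first=True (new behaviour): normalize the sub-layer input and add
      # the un-normalized source tensor back on the residual path.
      return x + sublayer(layer_norm(x))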
@@ -273,6 +293,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
@@ -289,6 +312,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -318,6 +344,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
@@ -330,13 +357,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+            name="self_attention_layer_norm", axis=-1,
+            epsilon=self._norm_epsilon))
     # Encoder-decoder attention.
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
+        use_bias=self._use_bias,
         name="attention/encdec",
         **common_kwargs)
@@ -344,7 +373,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.encdec_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
+            name="attention/encdec_output_layer_norm", axis=-1,
+            epsilon=self._norm_epsilon))
     # Feed-forward projection.
     self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
@@ -363,7 +393,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12)
+        name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
     super(TransformerDecoderLayer, self).build(input_shape)

   def common_layers_with_encoder(self):
@@ -384,6 +414,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" % len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
+    source_tensor = input_tensor
+    if self._norm_first:
+      input_tensor = self.self_attention_layer_norm(input_tensor)
     self_attention_output, cache = self.self_attention(
         query=input_tensor,
         value=input_tensor,
@@ -391,8 +424,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
-    self_attention_output = self.self_attention_layer_norm(
-        input_tensor + self_attention_output)
+    if self._norm_first:
+      self_attention_output = source_tensor + self_attention_output
+    else:
+      self_attention_output = self.self_attention_layer_norm(
+          input_tensor + self_attention_output)
+    if self._norm_first:
+      source_self_attention_output = self_attention_output
+      self_attention_output = self.encdec_attention_layer_norm(
+          self_attention_output)
     cross_attn_inputs = dict(
         query=self_attention_output,
         value=memory,
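In the decoder the same toggle is applied three times, once per sub-layer, and in pre-layer-norm mode each residual add uses the un-normalized output of the previous sub-layer, which the hunks above and below track as source_tensor, source_self_attention_output, and source_attention_output. A condensed sketch of that composition follows (dropout omitted, names illustrative).

    # Condensed pre-LN decoder block mirroring the source_* bookkeeping above.
    def pre_ln_decoder_block(x, memory, self_attn, cross_attn, ffn,
                             self_attn_norm, cross_attn_norm, output_norm):
      x = x + self_attn(self_attn_norm(x))            # residual source: x
      x = x + cross_attn(cross_attn_norm(x), memory)  # residual source: new x
      x = x + ffn(output_norm(x))                     # residual source: new x
      return x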
@@ -402,13 +442,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
       cross_attn_inputs["context_attention_weights"] = inputs[-1]
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
-    attention_output = self.encdec_attention_layer_norm(
-        self_attention_output + attention_output)
+    if self._norm_first:
+      attention_output = source_self_attention_output + attention_output
+    else:
+      attention_output = self.encdec_attention_layer_norm(
+          self_attention_output + attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self.output_layer_norm(attention_output)
     intermediate_output = self.intermediate_dense(attention_output)
     intermediate_output = self.intermediate_activation_layer(intermediate_output)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
-    layer_output = self.output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self.output_layer_norm(layer_output + attention_output)
     return layer_output, cache
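Taken together, the commit adds use_bias, norm_first, and norm_epsilon to both Transformer and TransformerDecoderLayer. A hedged usage sketch follows; the remaining constructor arguments, the import path, and the call signature are assumptions about the surrounding library, not part of this diff.

    # Hypothetical usage; only use_bias, norm_first and norm_epsilon come from
    # this commit, the other arguments are assumed from the surrounding API.
    import tensorflow as tf
    from official.nlp.modeling import layers  # assumed import path

    block = layers.Transformer(
        num_attention_heads=8,           # assumed existing argument
        intermediate_size=2048,          # assumed existing argument
        intermediate_activation="relu",  # assumed existing argument
        use_bias=False,      # new: forwarded to the attention layers
        norm_first=True,     # new: pre-layer-norm residual ordering
        norm_epsilon=1e-6,   # new: epsilon for every LayerNormalization
    )
    output = block(tf.random.uniform((2, 16, 512)))  # call signature assumed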