ModelZoo / ResNet50_tensorflow

Commit d3d4177d, authored Sep 10, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 395996268
Parent: a173db62
Showing 3 changed files with 93 additions and 19 deletions (+93, -19):
official/nlp/keras_nlp/layers/transformer_encoder_block.py   +19  -2
official/nlp/modeling/layers/transformer.py                  +59  -12
official/nlp/modeling/models/seq2seq_transformer.py          +15  -5
official/nlp/keras_nlp/layers/transformer_encoder_block.py (view file @ d3d4177d)

@@ -116,6 +116,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._attention_initializer = self._kernel_initializer
     self._attention_axes = attention_axes

+  def _maybe_build(self, inputs):
+    super()._maybe_build(inputs[:1])
+
   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
       input_tensor_shape = input_shape

@@ -247,6 +250,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       [`query tensor`, `key value tensor`, `attention mask`] to have separate
         input streams for the query, and key/value to the multi-head
         attention.
+      [`query tensor`, `key value tensor`, `attention mask`, `pos_embed`] to
+        have an additional pos_embed that is added to the query and key of
+        every self-attention layer.

     Returns:
       An output tensor with the same dimensions as input/query tensor.

@@ -255,13 +261,18 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       if len(inputs) == 2:
         input_tensor, attention_mask = inputs
         key_value = None
+        pos_embed = None
       elif len(inputs) == 3:
         input_tensor, key_value, attention_mask = inputs
+        pos_embed = None
+      elif len(inputs) == 4:
+        input_tensor, key_value, attention_mask, pos_embed = inputs
       else:
         raise ValueError("Unexpected inputs to %s with length at %d" %
                          (self.__class__, len(inputs)))
     else:
-      input_tensor, key_value, attention_mask = (inputs, None, None)
+      input_tensor, key_value, attention_mask, pos_embed = (inputs, None, None,
+                                                            None)
     if self._output_range:
       if self._norm_first:

@@ -282,8 +293,14 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     if key_value is None:
       key_value = input_tensor
+
+    if pos_embed is None:
+      query = target_tensor
+      key = key_value
+    else:
+      query = target_tensor + pos_embed
+      key = key_value + pos_embed
     attention_output = self._attention_layer(
-        query=target_tensor, value=key_value, attention_mask=attention_mask)
+        query=query, key=key, value=key_value, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     if self._norm_first:
       attention_output = source_tensor + attention_output
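With this change, TransformerEncoderBlock accepts an optional fourth input, pos_embed, which is added to the attention query and key (but not to the value). A minimal usage sketch follows; the tensor shapes, hyperparameters, and constructor arguments are illustrative assumptions, not taken from the commit:

import tensorflow as tf

from official.nlp.keras_nlp.layers import transformer_encoder_block

# Illustrative shapes: batch of 2, sequence length 8, hidden size 16.
inputs = tf.random.uniform((2, 8, 16))
attention_mask = tf.ones((2, 8, 8))
# Position embedding added to the attention query and key; it must be
# broadcastable against the input tensor.
pos_embed = tf.random.uniform((2, 8, 16))

# Constructor arguments are assumed from the layer's usual signature.
block = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=4, inner_dim=32, inner_activation="relu")

# New four-element input form: [query, key/value, attention mask, pos_embed].
output = block([inputs, inputs, attention_mask, pos_embed])

# The two- and three-element forms still work; pos_embed then defaults to
# None and the layer behaves exactly as before.

Because the embedding is added only to query and key, the values being attended over are left unchanged, which matches the behaviour described in the updated docstring.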
official/nlp/modeling/layers/transformer.py (view file @ d3d4177d)

@@ -232,6 +232,9 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     else:
       self._cross_attention_cls = attention.MultiHeadAttention

+  def _maybe_build(self, inputs):
+    super()._maybe_build(inputs[:1])
+
   def build(self, input_shape):
     target_tensor_shape = tf.TensorShape(input_shape[0])
     if len(target_tensor_shape.as_list()) != 3:

@@ -370,22 +373,57 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         self.intermediate_dense, self.output_dense, self.output_layer_norm
     ]

-  def call(self, inputs, cache=None, decode_loop_step=None):
-    if self.multi_channel_cross_attention:
-      if len(inputs) != 5:
-        raise ValueError(
-            "TransformerDecoderBlock must have 5 inputs, when it uses "
-            "multi_channel_cross_attention. But it got: %d" % len(inputs))
-    elif len(inputs) != 4:
-      raise ValueError(
-          "TransformerDecoderBlock must have 4 inputs, but it got: %d" %
-          len(inputs))
-    input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
+  def _parse_inputs(self, inputs, multi_channel_cross_attention):
+    if multi_channel_cross_attention:
+      if len(inputs) < 5:
+        raise ValueError(
+            "TransformerDecoderBlock must have at least 5 inputs, when it uses "
+            "multi_channel_cross_attention. But it got: %d" % len(inputs))
+      elif len(inputs) == 5:
+        (input_tensor, memory, attention_mask, self_attention_mask,
+         context_attention_weights) = inputs
+        input_pos_embed = None
+        memory_pos_embed = None
+      elif len(inputs) == 6:
+        (input_tensor, memory, attention_mask, self_attention_mask,
+         context_attention_weights, input_pos_embed) = inputs
+        memory_pos_embed = None
+      else:
+        (input_tensor, memory, attention_mask, self_attention_mask,
+         context_attention_weights, input_pos_embed,
+         memory_pos_embed) = inputs[:7]
+    else:
+      context_attention_weights = None
+      if len(inputs) < 4:
+        raise ValueError(
+            "TransformerDecoderBlock must have at least 4 inputs, but it "
+            "got: %d" % len(inputs))
+      elif len(inputs) == 4:
+        input_tensor, memory, attention_mask, self_attention_mask = inputs
+        input_pos_embed = None
+        memory_pos_embed = None
+      elif len(inputs) == 5:
+        (input_tensor, memory, attention_mask, self_attention_mask,
+         input_pos_embed) = inputs
+        memory_pos_embed = None
+      else:
+        (input_tensor, memory, attention_mask, self_attention_mask,
+         input_pos_embed, memory_pos_embed) = inputs[:6]
+    return (input_tensor, memory, attention_mask, self_attention_mask,
+            context_attention_weights, input_pos_embed, memory_pos_embed)
+
+  def call(self, inputs, cache=None, decode_loop_step=None):
+    (input_tensor, memory, attention_mask, self_attention_mask,
+     context_attention_weights, input_pos_embed,
+     memory_pos_embed) = self._parse_inputs(inputs,
+                                            self.multi_channel_cross_attention)
     source_tensor = input_tensor
     if self._norm_first:
       input_tensor = self.self_attention_layer_norm(input_tensor)
+    if input_pos_embed is None:
+      self_attn_query = input_tensor
+      self_attn_key = input_tensor
+    else:
+      self_attn_query = input_tensor + input_pos_embed
+      self_attn_key = input_tensor + input_pos_embed
     self_attention_output, cache = self.self_attention(
-        query=input_tensor,
+        query=self_attn_query,
+        key=self_attn_key,
         value=input_tensor,
         attention_mask=self_attention_mask,
         cache=cache,

@@ -400,13 +438,22 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
       source_self_attention_output = self_attention_output
       self_attention_output = self.encdec_attention_layer_norm(
           self_attention_output)
+    if input_pos_embed is None:
+      cross_attn_query = self_attention_output
+    else:
+      cross_attn_query = self_attention_output + input_pos_embed
+    if memory_pos_embed is None:
+      cross_attn_key = memory
+    else:
+      cross_attn_key = memory + memory_pos_embed
     cross_attn_inputs = dict(
-        query=self_attention_output,
+        query=cross_attn_query,
+        key=cross_attn_key,
         value=memory,
         attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs["context_attention_weights"] = inputs[-1]
+      cross_attn_inputs["context_attention_weights"] = context_attention_weights
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
     if self._norm_first:
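The refactor moves input unpacking into _parse_inputs and makes the positional embeddings optional trailing inputs. A sketch of calling the block without multi_channel_cross_attention, where the six-element input is [target, memory, cross-attention mask, self-attention mask, input_pos_embed, memory_pos_embed]; the constructor arguments and shapes are assumptions, not taken from the commit:

import tensorflow as tf

from official.nlp.modeling.layers import transformer

batch, tgt_len, src_len, hidden = 2, 6, 8, 16
target = tf.random.uniform((batch, tgt_len, hidden))
memory = tf.random.uniform((batch, src_len, hidden))
cross_attention_mask = tf.ones((batch, tgt_len, src_len))
self_attention_mask = tf.ones((batch, tgt_len, tgt_len))
input_pos_embed = tf.random.uniform((batch, tgt_len, hidden))
memory_pos_embed = tf.random.uniform((batch, src_len, hidden))

# Constructor arguments below are assumed; consult the layer's docstring.
decoder_block = transformer.TransformerDecoderBlock(
    num_attention_heads=4,
    intermediate_size=32,
    intermediate_activation="relu",
    dropout_rate=0.0,
    attention_dropout_rate=0.0)

# The last two elements are optional; omitting them reproduces the old
# four-element behaviour (no positional embeddings).
outputs = decoder_block([target, memory, cross_attention_mask,
                         self_attention_mask, input_pos_embed,
                         memory_pos_embed])

As the diff shows, input_pos_embed is added to the self-attention query/key and to the cross-attention query, while memory_pos_embed is added only to the cross-attention key, so the attended values themselves are unchanged.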
official/nlp/modeling/models/seq2seq_transformer.py (view file @ d3d4177d)

@@ -425,7 +425,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
     base_config = super(TransformerEncoder, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

-  def call(self, encoder_inputs, attention_mask=None):
+  def call(self, encoder_inputs, attention_mask=None, pos_embed=None):
     """Return the output of the encoder.

     Args:

@@ -433,14 +433,17 @@ class TransformerEncoder(tf.keras.layers.Layer):
         hidden_size)`.
       attention_mask: A mask for the encoder self-attention layer with shape
         `(batch_size, input_length, input_length)`.
+      pos_embed: A tensor or a float that is added to the query and key of
+        every self-attention layer. Defaults to None.

     Returns:
       Output of encoder which is a `float32` tensor with shape
       `(batch_size, input_length, hidden_size)`.
     """
     for layer_idx in range(self.num_layers):
       encoder_inputs = self.encoder_layers[layer_idx](
-          [encoder_inputs, attention_mask])
+          [encoder_inputs, encoder_inputs, attention_mask, pos_embed])

     output_tensor = encoder_inputs
     output_tensor = self.output_normalization(output_tensor)

@@ -519,7 +522,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
             attention_initializer=attention_initializer(input_shape[2]),
             name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
-        epsilon=1e-6, dtype="float32")
+        epsilon=self._norm_epsilon, dtype="float32")
     super(TransformerDecoder, self).build(input_shape)

   def get_config(self):

@@ -545,7 +548,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
            cross_attention_mask=None,
            cache=None,
            decode_loop_step=None,
-           return_all_decoder_outputs=False):
+           return_all_decoder_outputs=False,
+           input_pos_embed=None,
+           memory_pos_embed=None):
     """Return the output of the decoder layer stacks.

     Args:

@@ -565,6 +570,10 @@ class TransformerDecoder(tf.keras.layers.Layer):
       return_all_decoder_outputs: Return all decoder layer outputs.
        Note that the outputs are layer normed.
        This is useful when introducing per layer auxiliary loss.
+      input_pos_embed: A tensor or float that is added to the target embedding
+        in every self-attention and cross-attention layer. Defaults to None.
+      memory_pos_embed: A tensor or float that is added to the memory embedding
+        in every cross-attention layer. Defaults to None.

     Returns:
       Output of decoder.

@@ -575,7 +584,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
     decoder_outputs = []
     for layer_idx in range(self.num_layers):
       transformer_inputs = [
-          output_tensor, memory, cross_attention_mask, self_attention_mask
+          output_tensor, memory, cross_attention_mask, self_attention_mask,
+          input_pos_embed, memory_pos_embed
      ]
      # Gets the cache for decoding.
      if cache is None:
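At the model level, the new keyword arguments simply thread the positional embeddings down to every block: the encoder forwards pos_embed as the fourth input of each TransformerEncoderBlock, and the decoder appends input_pos_embed and memory_pos_embed to each TransformerDecoderBlock's input list. A usage sketch wiring an encoder and decoder together; the constructor arguments, the decoder's positional argument names, and the shapes are illustrative assumptions:

import tensorflow as tf

from official.nlp.modeling.models import seq2seq_transformer

batch, src_len, tgt_len, hidden = 2, 8, 6, 16

# Constructor arguments below are assumed, not taken from the commit.
encoder = seq2seq_transformer.TransformerEncoder(
    num_layers=2, num_attention_heads=4, intermediate_size=32)
decoder = seq2seq_transformer.TransformerDecoder(
    num_layers=2, num_attention_heads=4, intermediate_size=32)

src = tf.random.uniform((batch, src_len, hidden))
tgt = tf.random.uniform((batch, tgt_len, hidden))
src_mask = tf.ones((batch, src_len, src_len))
self_mask = tf.ones((batch, tgt_len, tgt_len))
cross_mask = tf.ones((batch, tgt_len, src_len))
src_pos_embed = tf.random.uniform((batch, src_len, hidden))
tgt_pos_embed = tf.random.uniform((batch, tgt_len, hidden))

# Encoder: pos_embed is added to the query and key of every self-attention
# layer in the stack.
memory = encoder(src, attention_mask=src_mask, pos_embed=src_pos_embed)

# Decoder: input_pos_embed conditions the target-side attention, while
# memory_pos_embed conditions the cross-attention keys over the memory.
outputs = decoder(
    tgt,
    memory,
    self_attention_mask=self_mask,
    cross_attention_mask=cross_mask,
    input_pos_embed=tgt_pos_embed,
    memory_pos_embed=src_pos_embed)

Leaving the new arguments at their default of None reproduces the previous behaviour of both stacks.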