ModelZoo / ResNet50_tensorflow

Commit fa211938, authored Sep 10, 2021 by A. Unique TensorFlower
Internal change

PiperOrigin-RevId: 396035361

parent b0707104
Showing 3 changed files with 19 additions and 93 deletions
official/nlp/keras_nlp/layers/transformer_encoder_block.py   +2  -19
official/nlp/modeling/layers/transformer.py                  +12 -59
official/nlp/modeling/models/seq2seq_transformer.py          +5  -15
official/nlp/keras_nlp/layers/transformer_encoder_block.py

@@ -116,9 +116,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       self._attention_initializer = self._kernel_initializer
     self._attention_axes = attention_axes
 
-  def _maybe_build(self, inputs):
-    super()._maybe_build(inputs[:1])
-
   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
       input_tensor_shape = input_shape
@@ -250,9 +247,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         [`query tensor`, `key value tensor`, `attention mask`] to have separate
           input streams for the query, and key/value to the multi-head
           attention.
-        [`query tensor`, `key value tensor`, `attention mask`, `pos_embed`] to
-          have an additional pos_embed that is added to the query and key of
-          every self-attention layer.
 
     Returns:
       An output tensor with the same dimensions as input/query tensor.
@@ -261,18 +255,13 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       if len(inputs) == 2:
         input_tensor, attention_mask = inputs
         key_value = None
-        pos_embed = None
       elif len(inputs) == 3:
         input_tensor, key_value, attention_mask = inputs
-        pos_embed = None
-      elif len(inputs) == 4:
-        input_tensor, key_value, attention_mask, pos_embed = inputs
       else:
         raise ValueError("Unexpected inputs to %s with length at %d" %
                          (self.__class__, len(inputs)))
     else:
-      input_tensor, key_value, attention_mask, pos_embed = (inputs, None, None,
-                                                            None)
+      input_tensor, key_value, attention_mask = (inputs, None, None)
 
     if self._output_range:
       if self._norm_first:
@@ -293,14 +282,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     if key_value is None:
       key_value = input_tensor
 
-    if pos_embed is None:
-      query = target_tensor
-      key = key_value
-    else:
-      query = target_tensor + pos_embed
-      key = key_value + pos_embed
     attention_output = self._attention_layer(
-        query=query, key=key, value=key_value, attention_mask=attention_mask)
+        query=target_tensor, value=key_value, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     if self._norm_first:
       attention_output = source_tensor + attention_output
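After this change, `TransformerEncoderBlock.call` accepts a single tensor or a list of two or three tensors; the four-element form with a trailing `pos_embed` is gone. A minimal sketch of the remaining call conventions, assuming the public import path and the illustrative constructor arguments and shapes below (none of which come from this commit):

```python
import tensorflow as tf

# Import path and constructor arguments are assumptions for illustration only.
from official.nlp.keras_nlp.layers import TransformerEncoderBlock

block = TransformerEncoderBlock(
    num_attention_heads=8, inner_dim=2048, inner_activation="relu")

x = tf.random.uniform([2, 16, 512])   # (batch_size, seq_length, hidden_size)
mask = tf.ones([2, 16, 16])           # (batch_size, seq_length, seq_length)

y1 = block(x)             # bare tensor -> (inputs, None, None)
y2 = block([x, mask])     # [query tensor, attention mask]
y3 = block([x, x, mask])  # [query tensor, key value tensor, attention mask]
# A 4-element list ([..., pos_embed]) now raises ValueError, since the
# pos_embed input stream was removed in this commit.
```

In every accepted form the output keeps the same dimensions as the input/query tensor, per the docstring above.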
official/nlp/modeling/layers/transformer.py

@@ -232,9 +232,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     else:
       self._cross_attention_cls = attention.MultiHeadAttention
 
-  def _maybe_build(self, inputs):
-    super()._maybe_build(inputs[:1])
-
   def build(self, input_shape):
     target_tensor_shape = tf.TensorShape(input_shape[0])
     if len(target_tensor_shape.as_list()) != 3:
@@ -373,57 +370,22 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         self.intermediate_dense, self.output_dense, self.output_layer_norm
     ]
 
-  def _parse_inputs(self, inputs, multi_channel_cross_attention):
-    if multi_channel_cross_attention:
-      if len(inputs) < 5:
-        raise ValueError(
-            "TransformerDecoderBlock must have at least 5 inputs, when it uses "
-            "multi_channel_cross_attention. But it got: %d" % len(inputs))
-      elif len(inputs) == 5:
-        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights = inputs
-        input_pos_embed = None
-        memory_pos_embed = None
-      elif len(inputs) == 6:
-        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed = inputs
-        memory_pos_embed = None
-      else:
-        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed = inputs[:7]
-    else:
-      context_attention_weights = None
-      if len(inputs) < 4:
-        raise ValueError(
-            "TransformerDecoderBlock must have at leaset 4 inputs, but it "
-            "got: %d" % len(inputs))
-      elif len(inputs) == 4:
-        input_tensor, memory, attention_mask, self_attention_mask = inputs
-        input_pos_embed = None
-        memory_pos_embed = None
-      elif len(inputs) == 5:
-        input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed = inputs
-        memory_pos_embed = None
-      else:
-        input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed, memory_pos_embed = inputs[:6]
-    return input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed
-
   def call(self, inputs, cache=None, decode_loop_step=None):
-    input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed = self._parse_inputs(
-        inputs, self.multi_channel_cross_attention)
+    if self.multi_channel_cross_attention:
+      if len(inputs) != 5:
+        raise ValueError(
+            "TransformerDecoderBlock must have 5 inputs, when it uses "
+            "multi_channel_cross_attention. But it got: %d" % len(inputs))
+    elif len(inputs) != 4:
+      raise ValueError(
+          "TransformerDecoderBlock must have 4 inputs, but it got: %d" %
+          len(inputs))
+    input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
     source_tensor = input_tensor
     if self._norm_first:
       input_tensor = self.self_attention_layer_norm(input_tensor)
-    if input_pos_embed is None:
-      self_attn_query = input_tensor
-      self_attn_key = input_tensor
-    else:
-      self_attn_query = input_tensor + input_pos_embed
-      self_attn_key = input_tensor + input_pos_embed
     self_attention_output, cache = self.self_attention(
-        query=self_attn_query,
-        key=self_attn_key,
+        query=input_tensor,
         value=input_tensor,
         attention_mask=self_attention_mask,
         cache=cache,
@@ -438,22 +400,13 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
       source_self_attention_output = self_attention_output
       self_attention_output = self.encdec_attention_layer_norm(
           self_attention_output)
-    if input_pos_embed is None:
-      cross_attn_query = self_attention_output
-    else:
-      cross_attn_query = self_attention_output + input_pos_embed
-    if memory_pos_embed is None:
-      cross_attn_key = memory
-    else:
-      cross_attn_key = memory + memory_pos_embed
     cross_attn_inputs = dict(
-        query=cross_attn_query,
-        key=cross_attn_key,
+        query=self_attention_output,
         value=memory,
         attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs["context_attention_weights"] = context_attention_weights
+      cross_attn_inputs["context_attention_weights"] = inputs[-1]
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
     if self._norm_first:
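With `_parse_inputs` gone, `TransformerDecoderBlock.call` again expects a fixed-length input list: exactly four tensors, or exactly five when `multi_channel_cross_attention` is set, the fifth carrying the document-attention probabilities read via `inputs[-1]`. A sketch of the four-input convention, assuming `decoder_block` is an already-constructed `TransformerDecoderBlock` (its constructor arguments are not shown in this diff) and that the block returns an `(output, cache)` pair in line with the caching pattern visible above:

```python
import tensorflow as tf

# `decoder_block` is assumed to be a built TransformerDecoderBlock instance
# from official/nlp/modeling/layers/transformer.py.
batch, target_len, source_len, hidden = 2, 8, 16, 512

targets = tf.random.uniform([batch, target_len, hidden])   # decoder input
memory = tf.random.uniform([batch, source_len, hidden])    # encoder output
cross_attention_mask = tf.ones([batch, target_len, source_len])
self_attention_mask = tf.ones([batch, target_len, target_len])

# Exactly four inputs, in this order; extra pos_embed entries are no longer
# accepted and would hit the ValueError added in this commit.
outputs, cache = decoder_block(
    [targets, memory, cross_attention_mask, self_attention_mask])
```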
official/nlp/modeling/models/seq2seq_transformer.py

@@ -425,7 +425,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
     base_config = super(TransformerEncoder, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def call(self, encoder_inputs, attention_mask=None, pos_embed=None):
+  def call(self, encoder_inputs, attention_mask=None):
     """Return the output of the encoder.
 
     Args:
@@ -433,17 +433,14 @@ class TransformerEncoder(tf.keras.layers.Layer):
         hidden_size)`.
       attention_mask: A mask for the encoder self-attention layer with shape
         `(batch_size, input_length, input_length)`.
-      pos_embed: A tensor or a float that is added to the query and key of every
-        self-attention layer. Defaults to None.
 
     Returns:
       Output of encoder which is a `float32` tensor with shape
       `(batch_size, input_length, hidden_size)`.
     """
     for layer_idx in range(self.num_layers):
       encoder_inputs = self.encoder_layers[layer_idx](
-          [encoder_inputs, encoder_inputs, attention_mask, pos_embed])
+          [encoder_inputs, attention_mask])
     output_tensor = encoder_inputs
     output_tensor = self.output_normalization(output_tensor)
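At the encoder-stack level, each layer now receives only `[encoder_inputs, attention_mask]`, so `TransformerEncoder.call` takes just the input embeddings and an optional mask. A minimal usage sketch; only the call signature comes from this diff, while the import path, constructor arguments, and shapes are illustrative assumptions:

```python
import tensorflow as tf

# Import path and constructor arguments are assumptions for illustration only.
from official.nlp.modeling.models.seq2seq_transformer import TransformerEncoder

encoder = TransformerEncoder(
    num_layers=2, num_attention_heads=8, intermediate_size=2048)

embeddings = tf.random.uniform([2, 16, 512])  # (batch_size, input_length, hidden_size)
mask = tf.ones([2, 16, 16])                   # (batch_size, input_length, input_length)

# pos_embed is no longer a parameter of call().
encoded = encoder(embeddings, attention_mask=mask)
# `encoded` is a float32 tensor of shape (batch_size, input_length, hidden_size).
```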
@@ -522,7 +519,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
               attention_initializer=attention_initializer(input_shape[2]),
               name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
-        epsilon=self._norm_epsilon, dtype="float32")
+        epsilon=1e-6, dtype="float32")
     super(TransformerDecoder, self).build(input_shape)
 
   def get_config(self):
@@ -548,9 +545,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
            cross_attention_mask=None,
            cache=None,
            decode_loop_step=None,
-           return_all_decoder_outputs=False,
-           input_pos_embed=None,
-           memory_pos_embed=None):
+           return_all_decoder_outputs=False):
     """Return the output of the decoder layer stacks.
 
     Args:
@@ -570,10 +565,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
       return_all_decoder_outputs: Return all decoder layer outputs.
         Note that the outputs are layer normed.
        This is useful when introducing per layer auxiliary loss.
-      input_pos_embed: A tensor or float that is added to the target embedding
-        in every self-attention and cross-attention layer. Defaults to None.
-      memory_pos_embed: A tensor or float that is added to the memory embedding
-        in every cross-attention layer. Defaults to None.
 
     Returns:
       Output of decoder.
@@ -584,8 +575,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     decoder_outputs = []
     for layer_idx in range(self.num_layers):
       transformer_inputs = [
-          output_tensor, memory, cross_attention_mask, self_attention_mask,
-          input_pos_embed, memory_pos_embed
+          output_tensor, memory, cross_attention_mask, self_attention_mask
       ]
       # Gets the cache for decoding.
       if cache is None:
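The decoder stack mirrors this: `input_pos_embed` and `memory_pos_embed` are dropped from `TransformerDecoder.call`, and each layer gets the four-element `transformer_inputs` list shown above. A sketch of the updated call, assuming `decoder` is an already-built `TransformerDecoder` and that the leading positional parameters (target embeddings and `memory`) plus the `self_attention_mask` keyword match the surrounding code, which this hunk does not show:

```python
import tensorflow as tf

# `decoder` is assumed to be a built TransformerDecoder instance from
# official/nlp/modeling/models/seq2seq_transformer.py; the leading
# (target, memory) parameters are assumed, not shown in this hunk.
batch, target_len, source_len, hidden = 2, 8, 16, 512

target_embeddings = tf.random.uniform([batch, target_len, hidden])
memory = tf.random.uniform([batch, source_len, hidden])  # encoder output
self_attention_mask = tf.ones([batch, target_len, target_len])
cross_attention_mask = tf.ones([batch, target_len, source_len])

# Remaining keyword arguments after this change; input_pos_embed and
# memory_pos_embed no longer exist.
decoded = decoder(
    target_embeddings,
    memory,
    self_attention_mask=self_attention_mask,
    cross_attention_mask=cross_attention_mask,
    cache=None,
    decode_loop_step=None,
    return_all_decoder_outputs=False)
```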