Commit fdecf385 authored Aug 24, 2022 by A. Unique TensorFlower

Optimized MultiHeadAttention layer to remove an unnecessary transpose

PiperOrigin-RevId: 469827655
Parent: bba1dad5
Showing 2 changed files with 245 additions and 0 deletions:

official/projects/edgetpu/vision/modeling/optimized_multiheadattention_layer.py       +164  -0
official/projects/edgetpu/vision/modeling/optimized_multiheadattention_layer_test.py  +81   -0
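In short, the stock `tf.keras.layers.MultiHeadAttention` computes the attention scores as a key-query multiplication (key is the first einsum operand), while this layer builds the score einsum with the query first. The result is identical; only the operand order changes, which is what lets the EdgeTPU/XLA compiler skip a transpose. Below is a minimal sketch of that equivalence with illustrative shapes; the equation strings are the ones `_build_attention_equation` in the new file generates for rank-4 inputs, and the key-first form mirrors the stock ordering described in the module docstring.

# Sketch only: query-first vs. key-first score einsums give the same result.
import tensorflow as tf

B, T, S, N, H = 2, 4, 4, 2, 8  # batch, query len, key len, heads, head dim
query = tf.random.uniform((B, T, N, H))
key = tf.random.uniform((B, S, N, H))

scores_query_first = tf.einsum("abcd,aecd->acbe", query, key)  # this commit
scores_key_first = tf.einsum("aecd,abcd->acbe", key, query)    # stock ordering
tf.debugging.assert_near(scores_query_first, scores_key_first)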
official/projects/edgetpu/vision/modeling/optimized_multiheadattention_layer.py  (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MultiHeadAttention layer optimized for EdgeTPU.
Compared to tf.keras.layers.MultiHeadAttention, this layer performs query-key
multiplication instead of key-query multiplication to remove an unnecessary
transpose.
"""

import math
import string
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf

_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(
    rank: int, attn_axes: Tuple[int, ...]) -> Tuple[str, str, int]:
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape as:
  `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
  `bs` and `<non-attention dims>` are treated as `<batch dims>`.

  The attention operations can be generalized:
  (1) Query-key dot product:
  `(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
  <key attention dims>, num_heads, channels) -> (<batch dims>,
  num_heads, <query attention dims>, <key attention dims>)`
  (2) Combination:
  `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
  (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch
  dims>, <query attention dims>, num_heads, channels)`

  Args:
    rank: Rank of query, key, value tensors.
    attn_axes: List/tuple of axes, `[-1, rank)`, that attention will be
      applied to.

  Returns:
    Einsum equations.
  """
  target_notation = _CHR_IDX[:rank]
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
  letter_offset = rank
  source_notation = ""
  for i in range(rank):
    if i in batch_dims or i == rank - 1:
      source_notation += target_notation[i]
    else:
      source_notation += _CHR_IDX[letter_offset]
      letter_offset += 1

  product_notation = "".join([target_notation[i] for i in batch_dims] +
                             [target_notation[i] for i in attn_axes] +
                             [source_notation[i] for i in attn_axes])
  dot_product_equation = "%s,%s->%s" % (
      target_notation,
      source_notation,
      product_notation,
  )
  attn_scores_rank = len(product_notation)
  combine_equation = "%s,%s->%s" % (
      product_notation,
      source_notation,
      target_notation,
  )
  return dot_product_equation, combine_equation, attn_scores_rank


class OptimizedMultiHeadAttention(tf.keras.layers.MultiHeadAttention):
  """MultiHeadAttention with query-key multiplication.

  Currently, this layer only works for self-attention but not for
  cross-attention. TODO(b/243166060).
  """

  def _build_attention(self, rank: int) -> None:
    """Builds multi-head dot-product attention computations.

    This function builds attributes necessary for `_compute_attention` to
    customize attention computation to replace the default dot-product
    attention.

    Args:
      rank: the rank of query, key, value tensors.
    """
    if self._attention_axes is None:
      self._attention_axes = tuple(range(1, rank - 2))
    else:
      self._attention_axes = tuple(self._attention_axes)
    (
        self._dot_product_equation,
        self._combine_equation,
        attn_scores_rank,
    ) = _build_attention_equation(rank, attn_axes=self._attention_axes)
    norm_axes = tuple(
        range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
    self._softmax = tf.keras.layers.Softmax(axis=norm_axes)
    self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)

  def _compute_attention(
      self,
      query: tf.Tensor,
      key: tf.Tensor,
      value: tf.Tensor,
      attention_mask: Optional[tf.Tensor] = None,
      training: Optional[bool] = None) -> Tuple[tf.Tensor, tf.Tensor]:
    """Applies Dot-product attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for
    customized attention implementation.

    Args:
      query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
      key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
      value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
      attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
        attention to certain positions. It is generally not needed if the
        `query` and `value` (and/or `key`) are masked.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
    # Note: Applying scalar multiply at the smaller end of einsum improves
    # XLA performance, but may introduce slight numeric differences in
    # the Transformer attention head.
    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, query, key)

    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_scores_dropout = self._dropout_layer(
        attention_scores, training=training)

    # `context_layer` = [B, T, N, H]
    attention_output = tf.einsum(self._combine_equation,
                                 attention_scores_dropout, value)
    return attention_output, attention_scores
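For reference, a small sketch of what `_build_attention_equation` generates for the common rank-4 case, i.e. projected tensors of shape `(batch, seq, num_heads, key_dim)`: the query notation comes first in the dot-product equation, so `_compute_attention` can pass `query` as the first einsum operand without a transpose. This is an illustrative snippet that simply calls the private helper above; the commented values follow from the helper's logic for this rank.

# Sketch: inspect the generated equations for rank 4, attn_axes=(1,).
from official.projects.edgetpu.vision.modeling import optimized_multiheadattention_layer as opt_mha

dot_eq, combine_eq, scores_rank = opt_mha._build_attention_equation(
    rank=4, attn_axes=(1,))
print(dot_eq)       # abcd,aecd->acbe : einsum(dot_eq, query, key) -> scores
print(combine_eq)   # acbe,aecd->abcd : einsum(combine_eq, scores, value)
print(scores_rank)  # 4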
official/projects/edgetpu/vision/modeling/optimized_multiheadattention_layer_test.py  (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for optimized_multiheadattention_layer."""

import numpy as np
import tensorflow as tf

from official.projects.edgetpu.vision.modeling import optimized_multiheadattention_layer

_BATCH_SIZE = 32
_SEQ_LEN = 4
_EMBEDDING_SIZE = 8
_NUM_HEADS = 2
_KEY_DIM = 2


class OptimizedMultiheadattentionLayerTest(tf.test.TestCase):

  def test_same_output(self):
    """Tests that OptimizedMultiHeadAttention returns the expected outputs."""
    input_tensor_1 = tf.random.uniform(
        (_BATCH_SIZE, _SEQ_LEN, _EMBEDDING_SIZE))
    input_tensor_2 = tf.random.uniform(
        (_BATCH_SIZE, _SEQ_LEN, _EMBEDDING_SIZE))

    # Instantiate layer and call with inputs to build.
    orig_layer = tf.keras.layers.MultiHeadAttention(
        num_heads=_NUM_HEADS, key_dim=_KEY_DIM)
    _ = orig_layer(input_tensor_1, input_tensor_2)
    opt_layer = optimized_multiheadattention_layer.OptimizedMultiHeadAttention(
        num_heads=_NUM_HEADS, key_dim=_KEY_DIM)
    _ = opt_layer(input_tensor_1, input_tensor_2)

    # Set the weights of the two layers to be the same.
    query_dense_weights = np.random.uniform(
        size=(_EMBEDDING_SIZE, _NUM_HEADS, _KEY_DIM))
    query_dense_bias = np.random.uniform(size=(_NUM_HEADS, _KEY_DIM))
    key_dense_weights = np.random.uniform(
        size=(_EMBEDDING_SIZE, _NUM_HEADS, _KEY_DIM))
    key_dense_bias = np.random.uniform(size=(_NUM_HEADS, _KEY_DIM))
    value_dense_weights = np.random.uniform(
        size=(_EMBEDDING_SIZE, _NUM_HEADS, _KEY_DIM))
    value_dense_bias = np.random.uniform(size=(_NUM_HEADS, _KEY_DIM))
    attention_output_dense_weights = np.random.uniform(
        size=(_NUM_HEADS, _KEY_DIM, _EMBEDDING_SIZE))
    attention_output_dense_bias = np.random.uniform(size=(_EMBEDDING_SIZE,))

    orig_layer._query_dense.set_weights(
        [query_dense_weights, query_dense_bias])
    orig_layer._key_dense.set_weights([key_dense_weights, key_dense_bias])
    orig_layer._value_dense.set_weights(
        [value_dense_weights, value_dense_bias])
    orig_layer._output_dense.set_weights(
        [attention_output_dense_weights, attention_output_dense_bias])

    opt_layer._query_dense.set_weights([query_dense_weights, query_dense_bias])
    opt_layer._key_dense.set_weights([key_dense_weights, key_dense_bias])
    opt_layer._value_dense.set_weights([value_dense_weights, value_dense_bias])
    opt_layer._output_dense.set_weights(
        [attention_output_dense_weights, attention_output_dense_bias])

    # Calculate two sets of attention outputs and scores and compare.
    orig_attn_output, orig_attn_score = orig_layer(
        input_tensor_1, input_tensor_2, return_attention_scores=True)
    opt_attn_output, opt_attn_score = opt_layer(
        input_tensor_1, input_tensor_2, return_attention_scores=True)
    self.assertAllClose(orig_attn_output, opt_attn_output)
    self.assertAllClose(orig_attn_score, opt_attn_score)


if __name__ == '__main__':
  tf.test.main()
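Beyond the parity test above, a minimal usage sketch (shapes are illustrative assumptions): the optimized layer is intended as a drop-in replacement for `tf.keras.layers.MultiHeadAttention` in the self-attention case.

# Hypothetical self-attention usage; shapes are illustrative.
import tensorflow as tf
from official.projects.edgetpu.vision.modeling import optimized_multiheadattention_layer

layer = optimized_multiheadattention_layer.OptimizedMultiHeadAttention(
    num_heads=2, key_dim=2)
x = tf.random.uniform((32, 4, 8))  # (batch, seq_len, embedding_size)
output, scores = layer(x, x, return_attention_scores=True)
print(output.shape)  # (32, 4, 8)
print(scores.shape)  # (32, 2, 4, 4) = (batch, num_heads, T, S)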