ModelZoo / ResNet50_tensorflow
Commit 051f1c96
Authored Sep 27, 2022 by Frederick Liu
Committed by A. Unique TensorFlower, Sep 27, 2022
[kernel] Improve readability by letting the user of cache to do the decay.
PiperOrigin-RevId: 477359324
Parent: a4235e26
Showing 1 changed file with 5 additions and 4 deletions
official/nlp/modeling/layers/kernel_attention.py
...
@@ -260,10 +260,11 @@ def causal_windowed_performer_attention(query_matrix,
     if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
       raise ValueError("window_decay should be in (0.0, 1.0) and not None.")
-    kv = cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", key_matrix, value_matrix)
-    cache["kv"] = kv * window_decay
-    k_sum = cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
-    cache["k_sum"] = k_sum * window_decay
+    kv = window_decay * cache["kv"] + tf.einsum(
+        "BTHD,BTHO->BHOD", key_matrix, value_matrix)
+    cache["kv"] = kv
+    k_sum = window_decay * cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
+    cache["k_sum"] = k_sum
     denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
     attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
                           1.0 / (denominator + _NUMERIC_STABLER))
...
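The change moves the decay from the point where the cache is written to the point where it is read. The sketch below is a scalar stand-in for the tensor code, not the library's implementation: each chunk contributes one number in place of the `tf.einsum` key-value product and one in place of the `tf.reduce_sum` over keys. The function names `step_old`, `step_new`, and `run`, and the value chosen for `_NUMERIC_STABLER`, are assumptions made for illustration. Assuming both variants start from a zero-initialized cache, it suggests that the attention outputs are unchanged and only the stored cache values differ, by one factor of `window_decay`.

```python
# Scalar sketch of the two cache-update orderings in this commit.
# All names below are hypothetical; the real code operates on tensors.

_NUMERIC_STABLER = 1e-6  # assumed value; the diff does not show the constant


def step_old(cache, q, kv_update, k_update, window_decay):
    """Pre-commit ordering: accumulate first, then store the decayed value."""
    kv = cache["kv"] + kv_update
    cache["kv"] = kv * window_decay
    k_sum = cache["k_sum"] + k_update
    cache["k_sum"] = k_sum * window_decay
    return q * kv / (q * k_sum + _NUMERIC_STABLER)


def step_new(cache, q, kv_update, k_update, window_decay):
    """Post-commit ordering: decay the cached value on read, store it as-is."""
    kv = window_decay * cache["kv"] + kv_update
    cache["kv"] = kv
    k_sum = window_decay * cache["k_sum"] + k_update
    cache["k_sum"] = k_sum
    return q * kv / (q * k_sum + _NUMERIC_STABLER)


def run(step, chunks, window_decay):
    """Feed (q, kv_update, k_update) chunks through one ordering."""
    cache = {"kv": 0.0, "k_sum": 0.0}  # assumes a zero-initialized cache
    outputs = [
        step(cache, q, kv_u, k_u, window_decay) for q, kv_u, k_u in chunks
    ]
    return outputs, cache


# Made-up inputs: (query, key-value contribution, key contribution).
chunks = [(1.0, 2.0, 1.0), (0.5, 3.0, 2.0), (2.0, 1.0, 0.5)]
old_outputs, old_cache = run(step_old, chunks, window_decay=0.5)
new_outputs, new_cache = run(step_new, chunks, window_decay=0.5)
```

One practical consequence under these assumptions: a cache written by the old ordering holds values already multiplied by `window_decay`, so feeding it to the new ordering would apply the decay twice on the first step.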