Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
68d6c14b
Commit
68d6c14b
authored Jul 13, 2022 by Krzysztof Choromanski
Committed by A. Unique TensorFlower on Jul 13, 2022
Browse files
Changing names of the coefficients defining FAVOR++ mechanism.
PiperOrigin-RevId: 460756776
parent
609f332f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 15 additions and 14 deletions
+15
-14
official/nlp/modeling/layers/kernel_attention.py
official/nlp/modeling/layers/kernel_attention.py
+15
-14
No files found.
official/nlp/modeling/layers/kernel_attention.py
View file @
68d6c14b
...
...
@@ -186,15 +186,15 @@ def expplus(data_orig,
d_prod
*=
(
2.0
/
(
l
*
l
))
ave
=
first_sum_of_squares
+
second_sum_of_squares
+
d_prod
dim
=
projection_matrix
.
shape
[
-
1
]
A
=
(
1.0
/
(
4.0
*
ave
))
*
(
a_coeff
=
(
1.0
/
(
4.0
*
ave
))
*
(
tf
.
math
.
sqrt
((
2.0
*
ave
+
dim
)
*
(
2.0
*
ave
+
dim
)
+
8.0
*
dim
*
ave
)
-
2.0
*
ave
-
dim
)
A
=
(
1.0
-
1.0
/
A
)
/
8.0
B
=
tf
.
math
.
sqrt
(
1.0
-
4.0
*
A
)
D
=
tf
.
math
.
pow
(
1.0
-
4.0
*
A
,
dim
/
4.0
)
A
=
tf
.
stop_gradient
(
A
)
B
=
tf
.
stop_gradient
(
B
)
D
=
tf
.
stop_gradient
(
D
)
a_coeff
=
(
1.0
-
1.0
/
a_coeff
)
/
8.0
b_coeff
=
tf
.
math
.
sqrt
(
1.0
-
4.0
*
a_coeff
)
d_coeff
=
tf
.
math
.
pow
(
1.0
-
4.0
*
a_coeff
,
dim
/
4.0
)
a_coeff
=
tf
.
stop_gradient
(
a_coeff
)
b_coeff
=
tf
.
stop_gradient
(
b_coeff
)
d_coeff
=
tf
.
stop_gradient
(
d_coeff
)
# Calculating diag_omega for the FAVOR++ mechanism:
diag_omega
=
tf
.
math
.
square
(
projection_matrix
)
...
...
@@ -203,24 +203,25 @@ def expplus(data_orig,
diag_omega
=
tf
.
expand_dims
(
diag_omega
,
axis
=
0
)
diag_omega
=
tf
.
expand_dims
(
diag_omega
,
axis
=
0
)
diag_omega
=
tf
.
expand_dims
(
diag_omega
,
axis
=
0
)
diag_omega
=
A
*
diag_omega
diag_omega
=
a_coeff
*
diag_omega
#
if
numerical_renormalizer
:
if
is_query
:
last_dims_t
=
(
len
(
data_dash
.
shape
)
-
1
,)
stab
=
B
*
tf
.
math
.
reduce_max
(
data_dash
,
axis
=
last_dims_t
,
keepdims
=
True
)
stab
=
b_coeff
*
tf
.
math
.
reduce_max
(
data_dash
,
axis
=
last_dims_t
,
keepdims
=
True
)
else
:
stab
=
B
*
tf
.
math
.
reduce_max
(
data_dash
,
keepdims
=
True
)
stab
=
b_coeff
*
tf
.
math
.
reduce_max
(
data_dash
,
keepdims
=
True
)
if
extra_renormalize_exp_fun
:
extra_stab
=
tf
.
reduce_max
(
diag_data
,
axis
=
1
,
keepdims
=
True
)
stab
=
tf
.
math
.
maximum
(
stab
,
extra_stab
)
data_dash
=
ratio
*
D
*
(
tf
.
math
.
exp
(
B
*
data_dash
-
stab
-
diag_data
+
diag_omega
)
+
data_dash
=
ratio
*
d_coeff
*
(
tf
.
math
.
exp
(
b_coeff
*
data_dash
-
stab
-
diag_data
+
diag_omega
)
+
numerical_stabilizer
)
else
:
data_dash
=
ratio
*
D
*
(
tf
.
math
.
exp
(
B
*
data_dash
-
diag_data
+
diag_omega
)
+
data_dash
=
ratio
*
d_coeff
*
(
tf
.
math
.
exp
(
b_coeff
*
data_dash
-
diag_data
+
diag_omega
)
+
numerical_stabilizer
)
return
data_dash
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.