Commit 68d6c14b authored by Krzysztof Choromanski's avatar Krzysztof Choromanski Committed by A. Unique TensorFlower
Browse files

Changing names of the coefficients defining FAVOR++ mechanism.

PiperOrigin-RevId: 460756776
parent 609f332f
......@@ -186,15 +186,15 @@ def expplus(data_orig,
d_prod *= (2.0 / (l * l))
ave = first_sum_of_squares + second_sum_of_squares + d_prod
dim = projection_matrix.shape[-1]
A = (1.0 / (4.0 * ave)) * (
a_coeff = (1.0 / (4.0 * ave)) * (
tf.math.sqrt((2.0 * ave + dim) *
(2.0 * ave + dim) + 8.0 * dim * ave) - 2.0 * ave - dim)
A = (1.0 - 1.0 / A) / 8.0
B = tf.math.sqrt(1.0 - 4.0 * A)
D = tf.math.pow(1.0 - 4.0 * A, dim / 4.0)
A = tf.stop_gradient(A)
B = tf.stop_gradient(B)
D = tf.stop_gradient(D)
a_coeff = (1.0 - 1.0 / a_coeff) / 8.0
b_coeff = tf.math.sqrt(1.0 - 4.0 * a_coeff)
d_coeff = tf.math.pow(1.0 - 4.0 * a_coeff, dim / 4.0)
a_coeff = tf.stop_gradient(a_coeff)
b_coeff = tf.stop_gradient(b_coeff)
d_coeff = tf.stop_gradient(d_coeff)
# Calculating diag_omega for the FAVOR++ mechanism:
diag_omega = tf.math.square(projection_matrix)
......@@ -203,24 +203,25 @@ def expplus(data_orig,
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = A * diag_omega
diag_omega = a_coeff * diag_omega
#
if numerical_renormalizer:
if is_query:
last_dims_t = (len(data_dash.shape) - 1,)
stab = B * tf.math.reduce_max(data_dash, axis=last_dims_t, keepdims=True)
stab = b_coeff * tf.math.reduce_max(
data_dash, axis=last_dims_t, keepdims=True)
else:
stab = B * tf.math.reduce_max(data_dash, keepdims=True)
stab = b_coeff * tf.math.reduce_max(data_dash, keepdims=True)
if extra_renormalize_exp_fun:
extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True)
stab = tf.math.maximum(stab, extra_stab)
data_dash = ratio * D * (
tf.math.exp(B * data_dash - stab - diag_data + diag_omega) +
data_dash = ratio * d_coeff * (
tf.math.exp(b_coeff * data_dash - stab - diag_data + diag_omega) +
numerical_stabilizer)
else:
data_dash = ratio * D * (
tf.math.exp(B * data_dash - diag_data + diag_omega) +
data_dash = ratio * d_coeff * (
tf.math.exp(b_coeff * data_dash - diag_data + diag_omega) +
numerical_stabilizer)
return data_dash
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment