Commit 68d6c14b authored by Krzysztof Choromanski, committed by A. Unique TensorFlower
Browse files

Changing the names of the coefficients defining the FAVOR++ mechanism (A, B and D become a_coeff, b_coeff and d_coeff).

PiperOrigin-RevId: 460756776
parent 609f332f
...@@ -186,15 +186,15 @@ def expplus(data_orig, ...@@ -186,15 +186,15 @@ def expplus(data_orig,
d_prod *= (2.0 / (l * l)) d_prod *= (2.0 / (l * l))
ave = first_sum_of_squares + second_sum_of_squares + d_prod ave = first_sum_of_squares + second_sum_of_squares + d_prod
dim = projection_matrix.shape[-1] dim = projection_matrix.shape[-1]
A = (1.0 / (4.0 * ave)) * ( a_coeff = (1.0 / (4.0 * ave)) * (
tf.math.sqrt((2.0 * ave + dim) * tf.math.sqrt((2.0 * ave + dim) *
(2.0 * ave + dim) + 8.0 * dim * ave) - 2.0 * ave - dim) (2.0 * ave + dim) + 8.0 * dim * ave) - 2.0 * ave - dim)
A = (1.0 - 1.0 / A) / 8.0 a_coeff = (1.0 - 1.0 / a_coeff) / 8.0
B = tf.math.sqrt(1.0 - 4.0 * A) b_coeff = tf.math.sqrt(1.0 - 4.0 * a_coeff)
D = tf.math.pow(1.0 - 4.0 * A, dim / 4.0) d_coeff = tf.math.pow(1.0 - 4.0 * a_coeff, dim / 4.0)
A = tf.stop_gradient(A) a_coeff = tf.stop_gradient(a_coeff)
B = tf.stop_gradient(B) b_coeff = tf.stop_gradient(b_coeff)
D = tf.stop_gradient(D) d_coeff = tf.stop_gradient(d_coeff)
# Calculating diag_omega for the FAVOR++ mechanism: # Calculating diag_omega for the FAVOR++ mechanism:
diag_omega = tf.math.square(projection_matrix) diag_omega = tf.math.square(projection_matrix)
...@@ -203,24 +203,25 @@ def expplus(data_orig, ...@@ -203,24 +203,25 @@ def expplus(data_orig,
diag_omega = tf.expand_dims(diag_omega, axis=0) diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = tf.expand_dims(diag_omega, axis=0) diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = tf.expand_dims(diag_omega, axis=0) diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = A * diag_omega diag_omega = a_coeff * diag_omega
# #
if numerical_renormalizer: if numerical_renormalizer:
if is_query: if is_query:
last_dims_t = (len(data_dash.shape) - 1,) last_dims_t = (len(data_dash.shape) - 1,)
stab = B * tf.math.reduce_max(data_dash, axis=last_dims_t, keepdims=True) stab = b_coeff * tf.math.reduce_max(
data_dash, axis=last_dims_t, keepdims=True)
else: else:
stab = B * tf.math.reduce_max(data_dash, keepdims=True) stab = b_coeff * tf.math.reduce_max(data_dash, keepdims=True)
if extra_renormalize_exp_fun: if extra_renormalize_exp_fun:
extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True) extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True)
stab = tf.math.maximum(stab, extra_stab) stab = tf.math.maximum(stab, extra_stab)
data_dash = ratio * D * ( data_dash = ratio * d_coeff * (
tf.math.exp(B * data_dash - stab - diag_data + diag_omega) + tf.math.exp(b_coeff * data_dash - stab - diag_data + diag_omega) +
numerical_stabilizer) numerical_stabilizer)
else: else:
data_dash = ratio * D * ( data_dash = ratio * d_coeff * (
tf.math.exp(B * data_dash - diag_data + diag_omega) + tf.math.exp(b_coeff * data_dash - diag_data + diag_omega) +
numerical_stabilizer) numerical_stabilizer)
return data_dash return data_dash
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment