Commit 051f1c96 authored by Frederick Liu, committed by A. Unique TensorFlower
Browse files

[kernel] Improve readability by letting the user of the cache do the decay.

PiperOrigin-RevId: 477359324
parent a4235e26
......@@ -260,10 +260,11 @@ def causal_windowed_performer_attention(query_matrix,
if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
raise ValueError("window_decay should be in (0.0, 1.0) and not None.")
kv = cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", key_matrix, value_matrix)
cache["kv"] = kv * window_decay
k_sum = cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
cache["k_sum"] = k_sum * window_decay
kv = window_decay * cache["kv"] + tf.einsum(
"BTHD,BTHO->BHOD", key_matrix, value_matrix)
cache["kv"] = kv
k_sum = window_decay * cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
cache["k_sum"] = k_sum
denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
1.0 / (denominator + _NUMERIC_STABLER))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment