Commit 25baa631 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[kernel] Update test to make usage clear

PiperOrigin-RevId: 480718252
parent be205db2
......@@ -68,10 +68,10 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
value=value,
attention_mask=masks,
training=training)
dim = num_random_features if num_random_features > 0 else key_dim
kv_cache = tf.zeros(
(batch_size, num_heads, key_dim,
num_random_features if num_random_features > 0 else key_dim))
k_sum_cache = tf.zeros((batch_size, num_heads, key_dim))
(batch_size, num_heads, dim, dim))
k_sum_cache = tf.zeros((batch_size, num_heads, dim))
stream_output = []
cache = {"kv": kv_cache, "k_sum": k_sum_cache}
for i in range(num_chunks):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment