Commit bae94e6d authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[kernel] fix test head shape. This does not cause an error because we overwirte the cache.

PiperOrigin-RevId: 479993154
parent 38c61e26
...@@ -71,7 +71,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase): ...@@ -71,7 +71,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
kv_cache = tf.zeros( kv_cache = tf.zeros(
(batch_size, num_heads, key_dim, (batch_size, num_heads, key_dim,
num_random_features if num_random_features > 0 else key_dim)) num_random_features if num_random_features > 0 else key_dim))
k_sum_cache = tf.zeros((batch_size, 1, key_dim)) k_sum_cache = tf.zeros((batch_size, num_heads, key_dim))
stream_output = [] stream_output = []
cache = {"kv": kv_cache, "k_sum": k_sum_cache} cache = {"kv": kv_cache, "k_sum": k_sum_cache}
for i in range(num_chunks): for i in range(num_chunks):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment