"vscode:/vscode.git/clone" did not exist on "ebec1c61db24859c415eed66d31827f8fce9e744"
Commit bae94e6d authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[kernel] fix test head shape. This does not cause an error because we overwirte the cache.

PiperOrigin-RevId: 479993154
parent 38c61e26
......@@ -71,7 +71,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
kv_cache = tf.zeros(
(batch_size, num_heads, key_dim,
num_random_features if num_random_features > 0 else key_dim))
k_sum_cache = tf.zeros((batch_size, 1, key_dim))
k_sum_cache = tf.zeros((batch_size, num_heads, key_dim))
stream_output = []
cache = {"kv": kv_cache, "k_sum": k_sum_cache}
for i in range(num_chunks):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment