Unverified commit 4fea040c, authored by Lianmin Zheng and committed by GitHub
Browse files

Fix a regression introduced by overlapping KV cache writing (#4375)

parent 6aaeb848
......@@ -326,7 +326,7 @@ class MHATokenToKVPool(KVCache):
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)
- if self.capture_mode:
+ if self.capture_mode and cache_k.shape[0] < 4:
self.alt_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.alt_stream):
self.k_buffer[layer_id][loc] = cache_k
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment