Unverified Commit 958678df authored by Casper's avatar Casper Committed by GitHub
Browse files

Fix condition when rolling cache (#150)

parent 92a403b2
...@@ -143,7 +143,7 @@ class QuantAttentionFused(nn.Module): ...@@ -143,7 +143,7 @@ class QuantAttentionFused(nn.Module):
will_cache_be_exceeded = self.start_pos + seqlen > self.max_seq_len will_cache_be_exceeded = self.start_pos + seqlen > self.max_seq_len
# Reset and avoid retaining state when processing context # Reset and avoid retaining state when processing context
if will_cache_be_exceeded: if will_cache_be_exceeded and seqlen > 1:
self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=self.start_pos) self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=self.start_pos)
# Slowly roll out old tokens without performance hit if exceeded during decoding # Slowly roll out old tokens without performance hit if exceeded during decoding
elif will_cache_be_exceeded and seqlen == 1: elif will_cache_be_exceeded and seqlen == 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment