Unverified Commit 6a05f68f authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: fix `SlidingWindowCache.reset()` (#31917)

fix sliding cache
parent e3143952
...@@ -971,13 +971,14 @@ class SlidingWindowCache(StaticCache): ...@@ -971,13 +971,14 @@ class SlidingWindowCache(StaticCache):
return k_out, v_out return k_out, v_out
def get_max_length(self) -> Optional[int]: def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is
# no matter how long the sentence is
return None return None
def reset(self): def reset(self):
self.key_cache.zero_() for layer_idx in range(len(self.key_cache)):
self.value_cache.zero_() # In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
class EncoderDecoderCache(Cache): class EncoderDecoderCache(Cache):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment