Unverified Commit 1c68f2ca authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[HybridCache] Fix `get_seq_length` method (#31661)

* fix gemma2

* handle in generate
parent 464aa746
...@@ -1083,7 +1083,7 @@ class HybridCache(Cache): ...@@ -1083,7 +1083,7 @@ class HybridCache(Cache):
# no matter how long the sentence is # no matter how long the sentence is
return self.max_cache_len return self.max_cache_len
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: def get_seq_length(self, layer_idx: Optional[int] = 0):
return None return None
def reset(self): def reset(self):
......
...@@ -1399,7 +1399,7 @@ class GenerationMixin: ...@@ -1399,7 +1399,7 @@ class GenerationMixin:
cache = model_kwargs["past_key_values"] cache = model_kwargs["past_key_values"]
if not isinstance(cache, Cache): if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2] past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length"): elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length() past_length = cache.get_seq_length()
if "inputs_embeds" in model_kwargs: if "inputs_embeds" in model_kwargs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment