Unverified Commit 08ad34b1 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Fix idefics cache (#31377)

* fix idefics cache

* fix tests
parent a2ede666
...@@ -1593,10 +1593,11 @@ class Idefics2Model(Idefics2PreTrainedModel): ...@@ -1593,10 +1593,11 @@ class Idefics2Model(Idefics2PreTrainedModel):
past_seen_tokens = 0 past_seen_tokens = 0
return_legacy_cache = False return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) if use_cache:
if not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_usable_length(seq_length) past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
...@@ -1669,7 +1670,7 @@ class Idefics2Model(Idefics2PreTrainedModel): ...@@ -1669,7 +1670,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
) )
if return_legacy_cache: if return_legacy_cache and use_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache() outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
...@@ -1880,8 +1881,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel): ...@@ -1880,8 +1881,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
# Omit tokens covered by past_key_values # Omit tokens covered by past_key_values
if past_key_values is not None: if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length() past_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length() max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens: # Keep only the unprocessed tokens:
...@@ -1900,7 +1900,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel): ...@@ -1900,7 +1900,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
if ( if (
max_cache_length is not None max_cache_length is not None
and attention_mask is not None and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length and past_length + input_ids.shape[1] > max_cache_length
): ):
attention_mask = attention_mask[:, -max_cache_length:] attention_mask = attention_mask[:, -max_cache_length:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment