Unverified Commit c02d302e authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Fix cache type in Idefics2 (#30729)

standardize cache in idefics2
parent 449894d2
...@@ -1591,9 +1591,10 @@ class Idefics2Model(Idefics2PreTrainedModel): ...@@ -1591,9 +1591,10 @@ class Idefics2Model(Idefics2PreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
past_seen_tokens = 0 past_seen_tokens = 0
if use_cache: return_legacy_cache = False
if not isinstance(past_key_values, Cache): if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
past_key_values = DynamicCache.from_legacy_cache(past_key_values) return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_usable_length(seq_length) past_seen_tokens = past_key_values.get_usable_length(seq_length)
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
...@@ -1667,6 +1668,9 @@ class Idefics2Model(Idefics2PreTrainedModel): ...@@ -1667,6 +1668,9 @@ class Idefics2Model(Idefics2PreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
) )
if return_legacy_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple(v for v in [*outputs, image_hidden_states] if v is not None) return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment