Unverified Commit f4e4b4d0 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: force cache with `inputs_embeds` forwarding (#24639)

parent 9934bb1f
...@@ -1304,6 +1304,11 @@ class GenerationMixin:
        # 4. Define other model kwargs
        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
else:
            model_kwargs["use_cache"] = generation_config.use_cache
        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment