Unverified Commit 9f81f4f6 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: force caching on the main model, in assisted generation (#24177)

parent 535f92ae
......@@ -4322,6 +4322,7 @@ class GenerationMixin:
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
outputs = self(
......@@ -4330,6 +4331,7 @@ class GenerationMixin:
past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
if self.config.is_encoder_decoder:
......@@ -4338,12 +4340,14 @@ class GenerationMixin:
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
outputs = self(
candidate_input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
# 2.2. Process the new logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment