Commit 9f81f4f6 (unverified), authored by Joao Gante, committed by GitHub
Browse files

Generate: force caching on the main model, in assisted generation (#24177)

parent 535f92ae
...@@ -4322,6 +4322,7 @@ class GenerationMixin: ...@@ -4322,6 +4322,7 @@ class GenerationMixin:
encoder_outputs=model_kwargs["encoder_outputs"], encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
use_cache=True,
) )
else: else:
outputs = self( outputs = self(
...@@ -4330,6 +4331,7 @@ class GenerationMixin: ...@@ -4330,6 +4331,7 @@ class GenerationMixin:
past_key_values=model_kwargs["past_key_values"], past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
use_cache=True,
) )
else: else:
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
...@@ -4338,12 +4340,14 @@ class GenerationMixin: ...@@ -4338,12 +4340,14 @@ class GenerationMixin:
encoder_outputs=model_kwargs["encoder_outputs"], encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
use_cache=True,
) )
else: else:
outputs = self( outputs = self(
candidate_input_ids, candidate_input_ids,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
use_cache=True,
) )
# 2.2. Process the new logits # 2.2. Process the new logits
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment