Unverified Commit 39799fbf authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[CI-Daily] replace `past` in prepare inputs for generation (#21296)

replace `past` in prepare inputs for generation
parent 23844941
......@@ -658,9 +658,9 @@ class EncoderDecoderModel(PreTrainedModel):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"attention_mask": attention_mask,
......@@ -679,6 +679,6 @@ class EncoderDecoderModel(PreTrainedModel):
" model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
def _reorder_cache(self, past_key_values, beam_idx):
# apply decoder cache reordering here
return self.decoder._reorder_cache(past, beam_idx)
return self.decoder._reorder_cache(past_key_values, beam_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment