Unverified commit 12eb528b, authored by Arthur, committed by GitHub

[CI] Remove `past` in favor of `past_key_values` (#21443)

* fix: rename `past` to `past_key_values`

* update more `past` occurrences that were skipped

* fixup

* remove changes made to rag

* refactor `_reorder_cache` to use `past_key_values`

* fix GIT `prepare_inputs_for_generation` so tests pass when `use_cache` needs to be `False`
parent 5b493762
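
The change below is mechanical but wide-reaching: every model's `_reorder_cache` (and the affected `prepare_inputs_for_generation` overrides) swaps the argument name `past` for `past_key_values`, matching the keyword the generation loop actually passes. As a minimal sketch of the post-rename contract (the class and tensors here are hypothetical, not from this diff): during beam search, each layer's cache is re-indexed along the batch dimension so every surviving beam keeps the states of the hypothesis it was forked from.

```python
import torch


class ToyModel:  # hypothetical GenerationMixin subclass
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # beam_idx[i] is the index of the old beam that new beam i
        # continues; index_select along dim 0 applies that mapping.
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )


cache = ((torch.zeros(3, 2, 4, 8), torch.zeros(3, 2, 4, 8)),)
picked = ToyModel._reorder_cache(cache, torch.tensor([2, 0, 0]))
```
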
```diff
@@ -724,7 +724,7 @@ class GenerationMixin:
         return model_kwargs
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         raise NotImplementedError(
             f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
             f" enable beam search for {self.__class__}"
```
```diff
@@ -1444,9 +1444,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
```
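
For the encoder-decoder models in this diff (Bart, BigBird-Pegasus, Blenderbot, BlenderbotSmall), each `layer_past` holds both self-attention and cross-attention states; as the inline comment notes, the cross-attention states are computed from the encoder output and are therefore identical for every beam. A toy illustration of that slicing, assuming the conventional four-tuple layout `(self_key, self_value, cross_key, cross_value)`:

```python
import torch

num_beams, heads, seq_len, head_dim = 4, 2, 5, 8
layer_past = tuple(torch.randn(num_beams, heads, seq_len, head_dim) for _ in range(4))
beam_idx = torch.tensor([2, 2, 0, 1])  # beam each new hypothesis descends from

# Only the self-attention key/value (entries [:2]) follow the beams;
# the cross-attention states (entries [2:]) are passed through unchanged.
reordered = tuple(t.index_select(0, beam_idx) for t in layer_past[:2]) + layer_past[2:]
assert torch.equal(reordered[2], layer_past[2])
```
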
```diff
@@ -1921,8 +1921,8 @@ class BartForCausalLM(BartPretrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -1288,9 +1288,9 @@ class BertLMHeadModel(BertPreTrainedModel):
             "use_cache": use_cache,
         }
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -2629,9 +2629,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -3098,8 +3098,8 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -714,8 +714,8 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -1401,9 +1401,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -1624,8 +1624,8 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -1367,9 +1367,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -1590,8 +1590,8 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -925,8 +925,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
             "is_decoder": True,
         }
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -1563,9 +1563,9 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -727,7 +727,9 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
         )
 
     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
         [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
@@ -735,5 +737,5 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
         """
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
+            for layer_past in past_key_values
         )
```
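
Unlike the plain tuple variants above, CodeGen (and CTRL, GPT-2, and GPT-Neo below) moves `beam_idx` onto each cached tensor's device before indexing, which matters when layers are sharded across devices under model parallelism. A small self-contained check of that behavior, CPU-only here:

```python
import torch

past_key_values = (
    (torch.randn(3, 2, 4, 8), torch.randn(3, 2, 4, 8)),  # layer 0
    (torch.randn(3, 2, 4, 8), torch.randn(3, 2, 4, 8)),  # layer 1
)
beam_idx = torch.tensor([1, 1, 0])

reordered = tuple(
    # .to(past_state.device) is a no-op on CPU but required when this
    # layer's cache lives on a different GPU than beam_idx.
    tuple(p.index_select(0, beam_idx.to(p.device)) for p in layer_past)
    for layer_past in past_key_values
)
assert torch.equal(reordered[0][0][0], past_key_values[0][0][1])
```
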
```diff
@@ -624,7 +624,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         )
 
     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
@@ -632,7 +634,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         """
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
+            for layer_past in past_key_values
         )
```
```diff
@@ -1024,9 +1024,9 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -1671,8 +1671,8 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
     # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -690,9 +690,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
         past_key_values = decoder_inputs.get("past_key_values")
         if past_key_values is None:
```
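
The TF composite model simply forwards the cache to its decoder's own `prepare_inputs_for_generation`, so the keyword has to be renamed on both sides at once. A hedged sketch of that delegation pattern (toy classes, not the transformers API):

```python
class ToyDecoder:
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None):
        return {"input_ids": input_ids, "past_key_values": past_key_values}


class ToyEncoderDecoder:
    def __init__(self, decoder):
        self.decoder = decoder

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # Outer and inner methods must agree on the keyword: passing the
        # old `past=` here would now raise a TypeError in the decoder.
        return self.decoder.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values
        )
```
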
```diff
@@ -1230,9 +1230,9 @@ class ErnieForCausalLM(ErniePreTrainedModel):
         }
 
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
```diff
@@ -1291,9 +1291,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
         return shift_tokens_right(labels, self.config.pad_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = []
-        for layer_past in past:
+        for layer_past in past_key_values:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
             layer_past_new = {
                 attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
```
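
FSMT is the odd one out: its cache is a list of nested dicts rather than tuples of tensors, so reordering goes through a `_reorder_buffer` helper applied to each attention cache. That helper's body is not shown in this hunk; the following is only a sketch of what it plausibly does, inferred from the call site:

```python
import torch


def _reorder_buffer(attn_cache, beam_idx):
    # Re-index every cached tensor (e.g. "prev_key", "prev_value") along
    # the batch dimension; non-tensor entries are left untouched.
    for key, value in attn_cache.items():
        if torch.is_tensor(value):
            attn_cache[key] = value.index_select(0, beam_idx)
    return attn_cache
```
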
```diff
@@ -1512,9 +1512,11 @@ class GitForCausalLM(GitPreTrainedModel):
             attentions=outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **kwargs):
-        # cut decoder_input_ids if past is used
-        if past is not None:
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+    ):
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]
 
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@@ -1526,12 +1528,12 @@ class GitForCausalLM(GitPreTrainedModel):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": kwargs.get("pixel_values", None),
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
```
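
The GIT change also keeps the standard generation-loop contract: once a cache exists, only the newest token is fed through the model, since everything earlier is already encoded in `past_key_values`. A minimal sketch of that trimming logic (hypothetical standalone function, not the GIT implementation):

```python
import torch


def prepare_inputs(input_ids, past_key_values=None):
    if past_key_values is not None:
        # All previous positions are covered by the cache, so only the
        # last token needs to be embedded and attended on this step.
        input_ids = input_ids[:, -1:]
    return {"input_ids": input_ids, "past_key_values": past_key_values}


step = prepare_inputs(torch.tensor([[5, 7, 9]]), past_key_values=((),))
assert step["input_ids"].tolist() == [[9]]
```
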
```diff
@@ -1117,7 +1117,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         )
 
     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
@@ -1125,7 +1127,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         """
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
+            for layer_past in past_key_values
         )
@@ -1336,7 +1338,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         )
 
     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
@@ -1344,7 +1348,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         """
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
+            for layer_past in past_key_values
         )
```
```diff
@@ -782,7 +782,9 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
         )
 
     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
         [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
@@ -790,7 +792,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
         """
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
+            for layer_past in past_key_values
         )
```
```diff
@@ -700,9 +700,9 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
             "past_key_values": past_key_values,
         }
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
```