Unverified commit 12eb528b, authored by Arthur and committed by GitHub

[CI] Remove `past` in favor of `past_key_values` (#21443)

* fix: rename `past` to `past_key_values`

* update more `past` occurrences that were skipped

* fixup

* remove changes made to rag

* refactor `_reorder_cache` to use `past_key_values`

* fix GIT `prepare_inputs_for_generation` to pass tests when `use_cache` needs to be `False`
parent 5b493762
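Every hunk below applies the same mechanical rename: the positional cache argument of `_reorder_cache` becomes `past_key_values`, matching the keyword that generation code uses elsewhere. A minimal sketch of what the hook does, assuming the tuple-of-tuples cache layout (one key/value pair per layer, batch-times-beams dimension first) that most models below use; the shapes and toy check are illustrative, not part of the commit:

```python
import torch

# Beam search keeps one cache slice per beam; whenever beams are reordered,
# every cached tensor must be permuted along the batch/beam dimension.
def _reorder_cache(past_key_values, beam_idx):
    reordered_past = ()
    for layer_past in past_key_values:
        reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
    return reordered_past

# Toy check: two layers, four beams, swap beams 0 and 1.
past = tuple((torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8)) for _ in range(2))
beam_idx = torch.tensor([1, 0, 2, 3])
reordered = _reorder_cache(past, beam_idx)
assert torch.equal(reordered[0][0][0], past[0][0][1])
```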
@@ -713,9 +713,9 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
@@ -878,7 +878,9 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         )
 
     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
         [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
@@ -886,7 +888,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         """
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
+            for layer_past in past_key_values
         )
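GPT-J and ImageGPT differ from the basic pattern only in moving `beam_idx` onto each past state's device before indexing; under model parallelism the per-layer caches can live on different GPUs, and `index_select` needs its index tensor on the same device as the data. A hedged toy illustration (devices and shapes made up, not from this diff):

```python
import torch

# Imagine these two cached states living on "cuda:1" in a sharded model,
# while beam_idx is produced on the main device.
layer_past = tuple(torch.randn(4, 2, 5, 8) for _ in range(2))
beam_idx = torch.tensor([2, 0, 1, 3])

# Moving the index to each state's device keeps index_select valid per shard.
reordered = tuple(
    past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past
)
assert torch.equal(reordered[0][0], layer_past[0][2])
```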
@@ -1060,7 +1060,9 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
         )
 
     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
@@ -1068,7 +1070,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
         """
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
+            for layer_past in past_key_values
         )
@@ -2508,9 +2508,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
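The encoder-decoder variants in this commit (LED above, Marian, MBart, Mvp, Pegasus, PLBart, ProphetNet below) cache four states per layer: self-attention key/value, which track the beams, plus cross-attention key/value computed once from the encoder output and therefore identical for every beam, hence only `layer_past[:2]` gets reordered. A toy sketch of that split (shapes illustrative):

```python
import torch

def reorder_layer_past(layer_past, beam_idx):
    # Only the self-attention key/value (layer_past[:2]) depend on beam order;
    # cross-attention states (layer_past[2:]) are shared across beams.
    return tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:]

layer_past = tuple(torch.randn(4, 2, 5, 8) for _ in range(4))
reordered = reorder_layer_past(layer_past, torch.tensor([3, 2, 1, 0]))
assert reordered[2] is layer_past[2]  # cross-attention carried over untouched
```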
@@ -2114,15 +2114,15 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return self._shift_right(labels)
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         # if decoder past is not included in output
         # speedy decoding is disabled and no need to reorder
-        if past is None:
+        if past_key_values is None:
             logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
+            return past_key_values
 
         reordered_decoder_past = ()
-        for layer_past_states in past:
+        for layer_past_states in past_key_values:
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` is at 2nd position
             reordered_layer_past_states = ()
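The T5-family variant (LongT5 above, MT5 below) also has to cover the no-cache case: with `use_cache=False` the model returns no decoder past, so the method warns and returns early instead of iterating. A simplified sketch of that control flow, assuming the same tuple cache; the real methods keep additional shape assertions:

```python
import logging
import torch

logger = logging.getLogger(__name__)

def _reorder_cache(past_key_values, beam_idx):
    # With use_cache=False there is no decoder past: nothing to reorder.
    if past_key_values is None:
        logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
        return past_key_values
    # Otherwise reorder every cached state along the batch/beam dimension.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past_key_values
    )

assert _reorder_cache(None, torch.tensor([0])) is None
```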
@@ -1395,8 +1395,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1521,9 +1521,9 @@ class MarianMTModel(MarianPreTrainedModel):
         return logits
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -1742,8 +1742,8 @@ class MarianForCausalLM(MarianPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -955,9 +955,9 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
         }
 
     # Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1418,9 +1418,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -1889,8 +1889,8 @@ class MBartForCausalLM(MBartPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1255,9 +1255,9 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1808,15 +1808,15 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
         return self._shift_right(labels)
 
     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         # if decoder past is not included in output
         # speedy decoding is disabled and no need to reorder
-        if past is None:
+        if past_key_values is None:
             logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
+            return past_key_values
 
         reordered_decoder_past = ()
-        for layer_past_states in past:
+        for layer_past_states in past_key_values:
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` is at 2nd position
             reordered_layer_past_states = ()
@@ -1579,9 +1579,9 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -2053,8 +2053,8 @@ class MvpForCausalLM(MvpPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -982,9 +982,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
         return model_inputs
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1478,9 +1478,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -1721,8 +1721,8 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1681,9 +1681,9 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -1391,9 +1391,9 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id)
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -1739,8 +1739,8 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
         }
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -2090,9 +2090,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
     @staticmethod
     # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
@@ -2336,9 +2336,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     @staticmethod
     # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1156,9 +1156,9 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1205,7 +1205,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         return self.rag.question_encoder
 
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
 
         def _reorder_stacked(hidden_states, new_order):
@@ -1216,7 +1216,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             return result
 
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
             reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
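RAG retrieves `n_docs` documents per beam, so each past state carries a batch dimension of `batch_size * n_docs` and cannot be indexed by `beam_idx` directly. The `_reorder_stacked` helper whose body is elided above plausibly regroups that dimension before indexing; a hedged reconstruction, not verbatim from the diff:

```python
import torch

def _reorder_stacked(hidden_states, new_order):
    # Recover the docs dimension, reorder along the beam dimension, flatten back.
    n_docs = hidden_states.shape[0] // new_order.shape[0]
    hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
    hidden_states = hidden_states.index_select(0, new_order.to(hidden_states.device))
    return hidden_states.view(-1, *hidden_states.shape[2:])

# Toy check: 2 beams x 3 docs, swapping the beams moves whole doc groups.
state = torch.arange(6).view(6, 1).float()
out = _reorder_stacked(state, torch.tensor([1, 0]))
assert out[:3].flatten().tolist() == [3.0, 4.0, 5.0]
```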
@@ -2298,9 +2298,9 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
         return inputs_dict
 
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reord_past_buckets_states = []
-        for layer_past in past:
+        for layer_past in past_key_values:
             # buckets
             if layer_past[0] is not None:
                 reord_buckets = layer_past[0].index_select(0, beam_idx)
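Reformer's cache differs again: each layer holds a `(buckets, hidden_states)` pair rather than key/value tensors, and `buckets` may be `None` for layers without LSH buckets, so it is only reordered when present. A hedged sketch of the per-layer handling the truncated hunk implies (shapes illustrative):

```python
import torch

def _reorder_cache(past_key_values, beam_idx):
    reord_past_buckets_states = []
    for buckets, hidden_states in past_key_values:
        # Buckets exist only for LSH layers; skip reordering when absent.
        reord_buckets = buckets.index_select(0, beam_idx) if buckets is not None else None
        reord_hidden_states = hidden_states.index_select(0, beam_idx)
        reord_past_buckets_states.append((reord_buckets, reord_hidden_states))
    return reord_past_buckets_states

past = [(None, torch.randn(2, 4, 8)), (torch.randint(0, 4, (2, 6)), torch.randn(2, 4, 8))]
reordered = _reorder_cache(past, torch.tensor([1, 0]))
assert reordered[0][0] is None
```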