Unverified Commit 12eb528b authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[CI] Remove `past` in favor of `past_key_values` (#21443)

* fix: `past` renamed to `past_key_values`

* update more `past` occurrences that were skipped

* fixup

* remove changes made to rag

* refactor `_reorder_cache` to use `past_key_values`

* fix git `prepare_inputs_for_generation` to pass tests when false is needed in use_cache
parent 5b493762
...@@ -724,7 +724,7 @@ class GenerationMixin: ...@@ -724,7 +724,7 @@ class GenerationMixin:
return model_kwargs return model_kwargs
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
raise NotImplementedError( raise NotImplementedError(
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
f" enable beam search for {self.__class__}" f" enable beam search for {self.__class__}"
......
...@@ -1444,9 +1444,9 @@ class BartForConditionalGeneration(BartPretrainedModel): ...@@ -1444,9 +1444,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
# cached cross_attention states don't have to be reordered -> they are always the same # cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += ( reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
...@@ -1921,8 +1921,8 @@ class BartForCausalLM(BartPretrainedModel): ...@@ -1921,8 +1921,8 @@ class BartForCausalLM(BartPretrainedModel):
} }
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
...@@ -1288,9 +1288,9 @@ class BertLMHeadModel(BertPreTrainedModel): ...@@ -1288,9 +1288,9 @@ class BertLMHeadModel(BertPreTrainedModel):
"use_cache": use_cache, "use_cache": use_cache,
} }
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
......
...@@ -2629,9 +2629,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): ...@@ -2629,9 +2629,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
# cached cross_attention states don't have to be reordered -> they are always the same # cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += ( reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
...@@ -3098,8 +3098,8 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel): ...@@ -3098,8 +3098,8 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
} }
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
...@@ -714,8 +714,8 @@ class BioGptForCausalLM(BioGptPreTrainedModel): ...@@ -714,8 +714,8 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
} }
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
...@@ -1401,9 +1401,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): ...@@ -1401,9 +1401,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
} }
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
# cached cross_attention states don't have to be reordered -> they are always the same # cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += ( reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
...@@ -1624,8 +1624,8 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel): ...@@ -1624,8 +1624,8 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
} }
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
...@@ -1367,9 +1367,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): ...@@ -1367,9 +1367,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
} }
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
# cached cross_attention states don't have to be reordered -> they are always the same # cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += ( reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
...@@ -1590,8 +1590,8 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel): ...@@ -1590,8 +1590,8 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
} }
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
...@@ -925,8 +925,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel): ...@@ -925,8 +925,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
"is_decoder": True, "is_decoder": True,
} }
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
...@@ -1563,9 +1563,9 @@ class CamembertForCausalLM(CamembertPreTrainedModel): ...@@ -1563,9 +1563,9 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
......
...@@ -727,7 +727,9 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel): ...@@ -727,7 +727,9 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
) )
@staticmethod @staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
""" """
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
...@@ -735,5 +737,5 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel): ...@@ -735,5 +737,5 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
""" """
return tuple( return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past for layer_past in past_key_values
) )
...@@ -624,7 +624,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -624,7 +624,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
) )
@staticmethod @staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
""" """
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
...@@ -632,7 +634,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -632,7 +634,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
""" """
return tuple( return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past for layer_past in past_key_values
) )
......
...@@ -1024,9 +1024,9 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel): ...@@ -1024,9 +1024,9 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
......
...@@ -1671,8 +1671,8 @@ class ElectraForCausalLM(ElectraPreTrainedModel): ...@@ -1671,8 +1671,8 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
...@@ -690,9 +690,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -690,9 +690,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
): ):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
past_key_values = decoder_inputs.get("past_key_values") past_key_values = decoder_inputs.get("past_key_values")
if past_key_values is None: if past_key_values is None:
......
...@@ -1230,9 +1230,9 @@ class ErnieForCausalLM(ErniePreTrainedModel): ...@@ -1230,9 +1230,9 @@ class ErnieForCausalLM(ErniePreTrainedModel):
} }
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
......
...@@ -1291,9 +1291,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): ...@@ -1291,9 +1291,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
return shift_tokens_right(labels, self.config.pad_token_id) return shift_tokens_right(labels, self.config.pad_token_id)
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = [] reordered_past = []
for layer_past in past: for layer_past in past_key_values:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn # get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = { layer_past_new = {
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
......
...@@ -1512,9 +1512,11 @@ class GitForCausalLM(GitPreTrainedModel): ...@@ -1512,9 +1512,11 @@ class GitForCausalLM(GitPreTrainedModel):
attentions=outputs.attentions, attentions=outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **kwargs): def prepare_inputs_for_generation(
# cut decoder_input_ids if past is used self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
if past is not None: ):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:] input_ids = input_ids[:, -1:]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
...@@ -1526,12 +1528,12 @@ class GitForCausalLM(GitPreTrainedModel): ...@@ -1526,12 +1528,12 @@ class GitForCausalLM(GitPreTrainedModel):
"input_ids": input_ids, "input_ids": input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"pixel_values": kwargs.get("pixel_values", None), "pixel_values": kwargs.get("pixel_values", None),
"past_key_values": past, "past_key_values": past_key_values,
"use_cache": use_cache, "use_cache": use_cache,
} }
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
...@@ -1117,7 +1117,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -1117,7 +1117,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
) )
@staticmethod @staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
""" """
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
...@@ -1125,7 +1127,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -1125,7 +1127,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
""" """
return tuple( return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past for layer_past in past_key_values
) )
...@@ -1336,7 +1338,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -1336,7 +1338,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
) )
@staticmethod @staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
""" """
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
...@@ -1344,7 +1348,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -1344,7 +1348,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
""" """
return tuple( return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past for layer_past in past_key_values
) )
......
...@@ -782,7 +782,9 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): ...@@ -782,7 +782,9 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
) )
@staticmethod @staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
""" """
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
...@@ -790,7 +792,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): ...@@ -790,7 +792,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
""" """
return tuple( return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past for layer_past in past_key_values
) )
......
...@@ -700,9 +700,9 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): ...@@ -700,9 +700,9 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
"past_key_values": past_key_values, "past_key_values": past_key_values,
} }
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past_key_values:
reordered_past += ( reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment